Skip to content

Commit b8f2a3a

Browse files
zxybazhJosh Fromm
authored andcommitted
[Relax][ONNX] Add Multiple ONNX Frontend Support for Clip / Equal / Shape / Not / Tanh (#3)
* Rebase w/ Equal, Not, Tanh, Sqrt, Relu, Clip, Conv, Pow, Erf. * Fix cumsum but still needs work.
1 parent 6cdf62f commit b8f2a3a

File tree

2 files changed

+382
-18
lines changed

2 files changed

+382
-18
lines changed

python/tvm/relax/frontend/onnx_frontend.py

Lines changed: 159 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_converter(cls, opset):
125125
return getattr(cls, "_impl_v{}".format(version))
126126
raise NotImplementedError(
127127
"opset version {} of {} not implemented".format(version, cls.__name__)
128-
)
128+
)
129129

130130

131131
class MatMul(OnnxOpConverter):
@@ -135,41 +135,50 @@ class MatMul(OnnxOpConverter):
135135
def _impl_v13(cls, bb, inputs, attr):
136136
return bb.emit_te(topi.matmul, inputs[0], inputs[1])
137137

138+
138139
class Div(OnnxOpConverter):
139140
"""Converts an onnx Div node into an equivalent Relax expression."""
141+
140142
@classmethod
141143
def _impl_v14(cls, bb, inputs, attr):
142144
return bb.emit_te(topi.divide, inputs[0], inputs[1])
143145

146+
144147
class Sigmoid(OnnxOpConverter):
145148
"""Converts an onnx Sigmoid node into an equivalent Relax expression."""
149+
146150
@classmethod
147151
def _impl_v13(cls, bb, inputs, attr):
148152
return bb.emit_te(topi.sigmoid, inputs[0])
149153

150154

151155
class Softmax(OnnxOpConverter):
152156
"""Converts an onnx Softmax node into an equivalent Relax expression."""
157+
153158
@classmethod
154159
def _impl_v13(cls, bb, inputs, attr):
155160
axis = attr.get("axis", -1)
156161
return bb.emit_te(topi.nn.softmax, inputs[0], axis=axis)
157162

163+
158164
class Transpose(OnnxOpConverter):
159165
"""Converts an onnx Transpose node into an equivalent Relax expression."""
166+
160167
@classmethod
161168
def _impl_v13(cls, bb, inputs, attr):
162169
perm = attr.get("perm", None)
163170
return bb.emit_te(topi.transpose, inputs[0], axes=perm)
164171

172+
165173
class Unsqueeze(OnnxOpConverter):
166174
"""Converts an onnx Unsqueeze node into an equivalent Relax expression."""
175+
167176
@classmethod
168177
def _impl_v13(cls, bb, inputs, attr):
169178
input = inputs[0]
170179
axes = inputs[1]
171180

172-
if (isinstance(axes, relax.Constant)):
181+
if isinstance(axes, relax.Constant):
173182
constant_axes = list(axes.data.numpy())
174183
constant_axes = list(map(int, constant_axes))
175184
constant_axes = sorted(constant_axes)
@@ -179,6 +188,7 @@ def _impl_v13(cls, bb, inputs, attr):
179188

180189
raise NotImplementedError("Unsqueeze with dynamic axes is not supported.")
181190

191+
182192
class Concat(OnnxOpConverter):
183193
"""Convert an onnx Concat node into an equivalent Relax expression."""
184194

@@ -207,6 +217,8 @@ def _impl_v13(cls, bb, inputs, attr):
207217
class Cast(OnnxOpConverter):
208218
"""Convert an onnx Cast node into an equivalent Relax expression."""
209219

220+
"""Convert an onnx Cast node into an equivalent Relax expression."""
221+
210222
@classmethod
211223
def _impl_v13(cls, bb, inputs, attr):
212224
to_type = get_type(attr["to"])
@@ -216,6 +228,8 @@ def _impl_v13(cls, bb, inputs, attr):
216228
class Gather(OnnxOpConverter):
217229
"""Convert an onnx Gather node into an equivalent Relax expression."""
218230

231+
"""Convert an onnx Gather node into an equivalent Relax expression."""
232+
219233
@classmethod
220234
def _impl_v13(cls, bb, inputs, attr):
221235
# TODO This assumes positive only indices.
@@ -255,16 +269,20 @@ def _impl_v13(cls, bb, inputs, attr):
255269
class Reshape(OnnxOpConverter):
256270
"""Convert an onnx Reshape node into an equivalent Relax expression."""
257271

272+
"""Convert an onnx Reshape node into an equivalent Relax expression."""
273+
258274
@classmethod
259275
def _impl_v13(cls, bb, inputs, attr):
260276
from tvm.script import relax as R
277+
261278
data = inputs[0]
262279
# TODO We assume new_shape is a constant, need to enable tensor input to reshape
263280
# for full support.
264281
new_shape = inputs[1].data.numpy()
265282

266283
# Convert -1 dims in new_shape into positive equivalent.
267284
if -1 in new_shape:
285+
breakpoint()
268286
data_shape = [dim.value for dim in data.shape.values]
269287
total_elements = np.prod(data_shape)
270288
new_product = 1
@@ -277,14 +295,15 @@ def _impl_v13(cls, bb, inputs, attr):
277295
if dim == -1:
278296
new_shape[i] = int(total_elements / new_product)
279297

280-
281298
return bb.emit_te(topi.reshape, data, new_shape)
282299

300+
283301
class Gelu(OnnxOpConverter):
284302
"""Operator converter for Gelu from Microsoft onnxruntime contrib opset.
285303
286304
gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
287305
"""
306+
288307
@classmethod
289308
def _impl_v1(cls, bb, inputs, attr):
290309
x = inputs[0]
@@ -297,15 +316,17 @@ def _impl_v1(cls, bb, inputs, attr):
297316

298317
# Compute gelu
299318
term1 = bb.emit_te(topi.multiply, half, x)
300-
erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2))
319+
erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2))
301320
term2 = bb.emit_te(topi.add, one, erf)
302321
return bb.emit_te(topi.multiply, term1, term2)
303322

323+
304324
class BiasGelu(OnnxOpConverter):
305325
"""Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
306326
307327
bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
308328
"""
329+
309330
@classmethod
310331
def _impl_v1(cls, bb, inputs, attr):
311332
x = inputs[0]
@@ -317,12 +338,134 @@ def _impl_v1(cls, bb, inputs, attr):
317338
inp = bb.emit_te(topi.add, x, b)
318339
return Gelu._impl_v1(bb, [inp], attr)
319340

341+
320342
class Where(OnnxOpConverter):
321343
"""Convert an onnx Where node into an equivalent Relax expression."""
344+
322345
@classmethod
323346
def _impl_v16(cls, bb, inputs, attr):
324347
return bb.emit_te(topi.where, *inputs)
325348

349+
350+
class Clip(OnnxOpConverter):
351+
"""Converts an onnx Clip node into an equivalent Relax expression."""
352+
353+
@classmethod
354+
def _impl_v13(cls, bb, inputs, attr):
355+
results = inputs[0]
356+
if len(inputs) >= 2:
357+
results = bb.emit_te(topi.maximum, results, inputs[1])
358+
if len(inputs) >= 3:
359+
results = bb.emit_te(topi.minimum, results, inputs[2])
360+
return results
361+
362+
363+
class Equal(OnnxOpConverter):
364+
"""Converts an onnx Equal node into an equivalent Relax expression."""
365+
366+
@classmethod
367+
def _impl_v13(cls, bb, inputs, attr):
368+
return bb.emit_te(topi.equal, inputs[0], inputs[1])
369+
370+
371+
class Shape(OnnxOpConverter):
372+
"""Converts an onnx Equal node into an equivalent Relax expression."""
373+
374+
@classmethod
375+
def _impl_v13(cls, bb, inputs, attr):
376+
return bb.emit_te(topi.shape, inputs[0], inputs[1])
377+
378+
379+
class Not(OnnxOpConverter):
380+
"""Converts an onnx Not node into an equivalent Relax expression."""
381+
382+
@classmethod
383+
def _impl_v13(cls, bb, inputs, attr):
384+
return bb.emit_te(topi.bitwise_not, inputs[0])
385+
386+
387+
class Tanh(OnnxOpConverter):
388+
"""Converts an onnx Tanh node into an equivalent Relax expression."""
389+
390+
@classmethod
391+
def _impl_v13(cls, bb, inputs, attr):
392+
return bb.emit_te(topi.tanh, inputs[0])
393+
394+
395+
class Sqrt(OnnxOpConverter):
396+
"""Converts an onnx Sqrt node into an equivalent Relax expression."""
397+
398+
@classmethod
399+
def _impl_v13(cls, bb, inputs, attr):
400+
return bb.emit_te(topi.sqrt, inputs[0])
401+
402+
403+
class Relu(OnnxOpConverter):
404+
"""Converts an onnx Relu node into an equivalent Relax expression."""
405+
406+
@classmethod
407+
def _impl_v13(cls, bb, inputs, attr):
408+
return bb.emit_te(topi.nn.relu, inputs[0])
409+
410+
411+
class Pow(OnnxOpConverter):
412+
"""Converts an onnx Pow node into an equivalent Relax expression."""
413+
414+
@classmethod
415+
def _impl_v13(cls, bb, inputs, attr):
416+
return bb.emit_te(topi.power, inputs[0], inputs[1])
417+
418+
419+
class Conv(OnnxOpConverter):
420+
"""Convert an onnx Conv node into an equivalent Relax expression."""
421+
422+
@classmethod
423+
def _impl_v13(cls, bb, inputs, attr):
424+
# not supported yet
425+
assert "auto_pad" not in attr
426+
assert "group" not in attr
427+
# supported conv2d
428+
return bb.emit_te(
429+
topi.add,
430+
bb.emit_te(
431+
topi.nn.conv2d,
432+
inputs[0],
433+
inputs[1],
434+
strides=attr.get("strides", 1),
435+
padding=attr.get("pads", 0),
436+
dilation=attr.get("dilations", 1),
437+
),
438+
bb.emit_te(topi.expand_dims, inputs[2], axis=1, num_newaxis=2),
439+
)
440+
441+
442+
class Erf(OnnxOpConverter):
443+
"""Converts an onnx Erf node into an equivalent Relax expression."""
444+
445+
@classmethod
446+
def _impl_v13(cls, bb, inputs, attr):
447+
return bb.emit_te(topi.erf, inputs[0])
448+
449+
450+
class CumSum(OnnxOpConverter):
451+
"""Converts an onnx CumSum node into an equivalent Relax expression."""
452+
453+
@classmethod
454+
def _impl_v13(cls, bb, inputs, attr):
455+
assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet"
456+
if len(inputs) > 1:
457+
# axis = int(infer_value(inputs[1], params).numpy())
458+
axis = inputs[1]
459+
else:
460+
axis = None
461+
return bb.emit_te(
462+
topi.cumsum,
463+
data=inputs[0],
464+
axis=axis,
465+
exclusive=attr.get("exclusive", None),
466+
)
467+
468+
326469
def _get_convert_map(opset):
327470
return {
328471
"MatMul": MatMul.get_converter(opset),
@@ -341,6 +484,17 @@ def _get_convert_map(opset):
341484
"Gelu": Gelu.get_converter(opset),
342485
"BiasGelu": BiasGelu.get_converter(opset),
343486
"Where": Where.get_converter(opset),
487+
"Clip": Clip.get_converter(opset),
488+
"Equal": Equal.get_converter(opset),
489+
"Shape": Shape.get_converter(opset),
490+
"Not": Not.get_converter(opset),
491+
"Tanh": Tanh.get_converter(opset),
492+
"Sqrt": Sqrt.get_converter(opset),
493+
"Relu": Relu.get_converter(opset),
494+
"Conv": Conv.get_converter(opset),
495+
"Pow": Pow.get_converter(opset),
496+
"Erf": Erf.get_converter(opset),
497+
"CumSum": CumSum.get_converter(opset),
344498
}
345499

346500

@@ -630,4 +784,4 @@ def from_onnx(model, shape=None, dtype="float32", opset=None):
630784
)
631785

632786
# Use the graph proto as a scope so that ops can access other nodes if needed.
633-
return g.from_onnx(graph, opset)
787+
return g.from_onnx(graph, opset)

0 commit comments

Comments
 (0)