@@ -125,7 +125,7 @@ def get_converter(cls, opset):
             return getattr(cls, "_impl_v{}".format(version))
         raise NotImplementedError(
             "opset version {} of {} not implemented".format(version, cls.__name__)
-        )
+        )
 
 
 class MatMul(OnnxOpConverter):
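
For context, the version-dispatch logic above appears only in part in this hunk. As a rough standalone sketch (not the actual TVM class; the selection rule shown here is an assumption), `get_converter` can be read as picking the newest `_impl_v{N}` whose version is at or below the model's opset:

# Standalone sketch of the dispatch pattern; illustrative only.
class OpConverterSketch:
    @classmethod
    def get_converter(cls, opset):
        # Collect the versions of all _impl_v{N} methods defined on the class.
        versions = sorted(
            int(name.replace("_impl_v", ""))
            for name in dir(cls)
            if name.startswith("_impl_v")
        )
        candidates = [v for v in versions if v <= opset]
        if candidates:
            return getattr(cls, "_impl_v{}".format(candidates[-1]))
        raise NotImplementedError(
            "opset version {} of {} not implemented".format(opset, cls.__name__)
        )
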
@@ -135,41 +135,50 @@ class MatMul(OnnxOpConverter):
     def _impl_v13(cls, bb, inputs, attr):
         return bb.emit_te(topi.matmul, inputs[0], inputs[1])
 
+
 class Div(OnnxOpConverter):
     """Converts an onnx Div node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v14(cls, bb, inputs, attr):
         return bb.emit_te(topi.divide, inputs[0], inputs[1])
 
+
 class Sigmoid(OnnxOpConverter):
     """Converts an onnx Sigmoid node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         return bb.emit_te(topi.sigmoid, inputs[0])
 
 
 class Softmax(OnnxOpConverter):
     """Converts an onnx Softmax node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         axis = attr.get("axis", -1)
         return bb.emit_te(topi.nn.softmax, inputs[0], axis=axis)
 
+
 class Transpose(OnnxOpConverter):
     """Converts an onnx Transpose node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         perm = attr.get("perm", None)
         return bb.emit_te(topi.transpose, inputs[0], axes=perm)
 
+
 class Unsqueeze(OnnxOpConverter):
     """Converts an onnx Unsqueeze node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         input = inputs[0]
         axes = inputs[1]
 
-        if (isinstance(axes, relax.Constant)):
+        if isinstance(axes, relax.Constant):
             constant_axes = list(axes.data.numpy())
             constant_axes = list(map(int, constant_axes))
             constant_axes = sorted(constant_axes)
@@ -179,6 +188,7 @@ def _impl_v13(cls, bb, inputs, attr):
 
         raise NotImplementedError("Unsqueeze with dynamic axes is not supported.")
 
+
 class Concat(OnnxOpConverter):
     """Convert an onnx Concat node into an equivalent Relax expression."""
 
@@ -207,6 +217,8 @@ def _impl_v13(cls, bb, inputs, attr):
 class Cast(OnnxOpConverter):
     """Convert an onnx Cast node into an equivalent Relax expression."""
 
220+ """Convert an onnx Cast node into an equivalent Relax expression."""
221+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         to_type = get_type(attr["to"])
@@ -216,6 +228,8 @@ def _impl_v13(cls, bb, inputs, attr):
 class Gather(OnnxOpConverter):
     """Convert an onnx Gather node into an equivalent Relax expression."""
 
231+ """Convert an onnx Gather node into an equivalent Relax expression."""
232+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         # TODO This assumes positive only indices.
@@ -255,16 +269,20 @@ def _impl_v13(cls, bb, inputs, attr):
 class Reshape(OnnxOpConverter):
     """Convert an onnx Reshape node into an equivalent Relax expression."""
 
272+ """Convert an onnx Reshape node into an equivalent Relax expression."""
273+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         from tvm.script import relax as R
+
         data = inputs[0]
         # TODO We assume new_shape is a constant, need to enable tensor input to reshape
         # for full support.
         new_shape = inputs[1].data.numpy()
 
         # Convert -1 dims in new_shape into positive equivalent.
         if -1 in new_shape:
             data_shape = [dim.value for dim in data.shape.values]
             total_elements = np.prod(data_shape)
             new_product = 1
@@ -277,14 +295,15 @@ def _impl_v13(cls, bb, inputs, attr):
                 if dim == -1:
                     new_shape[i] = int(total_elements / new_product)
 
-
         return bb.emit_te(topi.reshape, data, new_shape)
 
+
 class Gelu(OnnxOpConverter):
     """Operator converter for Gelu from Microsoft onnxruntime contrib opset.
 
     gelu(x) = 0.5x(1 + erf(x/sqrt(2)))
     """
+
     @classmethod
     def _impl_v1(cls, bb, inputs, attr):
         x = inputs[0]
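
For intuition, the -1 handling in Reshape above infers the one unknown dimension so that the total element count is preserved. A small NumPy check of the same arithmetic, with illustrative values:

import numpy as np

data_shape = [2, 3, 4]  # 24 elements in total
new_shape = [4, -1]     # the -1 must resolve to 6
total_elements = np.prod(data_shape)
known_product = np.prod([d for d in new_shape if d != -1])
resolved = [int(total_elements // known_product) if d == -1 else d for d in new_shape]
assert resolved == [4, 6]
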
@@ -297,15 +316,17 @@ def _impl_v1(cls, bb, inputs, attr):
 
         # Compute gelu
         term1 = bb.emit_te(topi.multiply, half, x)
-        erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2))
+        erf = bb.emit_te(topi.erf, bb.emit_te(topi.divide, x, sqrt2))
         term2 = bb.emit_te(topi.add, one, erf)
         return bb.emit_te(topi.multiply, term1, term2)
 
+
 class BiasGelu(OnnxOpConverter):
     """Operator converter for BiasGelu from Microsoft onnxruntime contrib opset.
 
     bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2)))
     """
+
     @classmethod
     def _impl_v1(cls, bb, inputs, attr):
         x = inputs[0]
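
A quick numeric sanity check of the Gelu identity quoted in the docstring, using only the standard library (the test values are illustrative):

import math

def gelu(x):
    # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), as in the docstring above.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

assert abs(gelu(0.0)) < 1e-12           # gelu(0) == 0
assert abs(gelu(3.0) - 2.99595) < 1e-3  # approaches x for large positive inputs
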
@@ -317,12 +338,134 @@ def _impl_v1(cls, bb, inputs, attr):
         inp = bb.emit_te(topi.add, x, b)
         return Gelu._impl_v1(bb, [inp], attr)
 
+
 class Where(OnnxOpConverter):
     """Convert an onnx Where node into an equivalent Relax expression."""
+
     @classmethod
     def _impl_v16(cls, bb, inputs, attr):
         return bb.emit_te(topi.where, *inputs)
 
+
+class Clip(OnnxOpConverter):
+    """Converts an onnx Clip node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        results = inputs[0]
+        if len(inputs) >= 2:
+            results = bb.emit_te(topi.maximum, results, inputs[1])
+        if len(inputs) >= 3:
+            results = bb.emit_te(topi.minimum, results, inputs[2])
+        return results
+
+
+class Equal(OnnxOpConverter):
+    """Converts an onnx Equal node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.equal, inputs[0], inputs[1])
+
+
+class Shape(OnnxOpConverter):
+    """Converts an onnx Shape node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        # ONNX Shape has a single input tensor.
+        return bb.emit_te(topi.shape, inputs[0])
+
+
+class Not(OnnxOpConverter):
+    """Converts an onnx Not node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.bitwise_not, inputs[0])
+
+
+class Tanh(OnnxOpConverter):
+    """Converts an onnx Tanh node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.tanh, inputs[0])
+
+
+class Sqrt(OnnxOpConverter):
+    """Converts an onnx Sqrt node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.sqrt, inputs[0])
+
+
+class Relu(OnnxOpConverter):
+    """Converts an onnx Relu node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.nn.relu, inputs[0])
+
+
+class Pow(OnnxOpConverter):
+    """Converts an onnx Pow node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.power, inputs[0], inputs[1])
+
+class Conv(OnnxOpConverter):
+    """Convert an onnx Conv node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        # The "auto_pad" and "group" attributes are not supported yet.
+        assert "auto_pad" not in attr
+        assert "group" not in attr
+        # Only conv2d is supported: convolve, then add the bias broadcast
+        # across the spatial dimensions.
+        return bb.emit_te(
+            topi.add,
+            bb.emit_te(
+                topi.nn.conv2d,
+                inputs[0],
+                inputs[1],
+                strides=attr.get("strides", 1),
+                padding=attr.get("pads", 0),
+                dilation=attr.get("dilations", 1),
+            ),
+            bb.emit_te(topi.expand_dims, inputs[2], axis=1, num_newaxis=2),
+        )
+
+
+class Erf(OnnxOpConverter):
+    """Converts an onnx Erf node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.erf, inputs[0])
+
+
+class CumSum(OnnxOpConverter):
+    """Converts an onnx CumSum node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        # attr is a dict, so use .get rather than getattr (which would
+        # always return the default and make this assert a no-op).
+        assert attr.get("reverse", 0) == 0, "reverse is not supported yet"
+        if len(inputs) > 1:
+            # axis = int(infer_value(inputs[1], params).numpy())
+            axis = inputs[1]
+        else:
+            axis = None
+        return bb.emit_te(
+            topi.cumsum,
+            data=inputs[0],
+            axis=axis,
+            exclusive=attr.get("exclusive", None),
+        )
+
+
 def _get_convert_map(opset):
     return {
         "MatMul": MatMul.get_converter(opset),
@@ -341,6 +484,17 @@ def _get_convert_map(opset):
341484 "Gelu" : Gelu .get_converter (opset ),
342485 "BiasGelu" : BiasGelu .get_converter (opset ),
343486 "Where" : Where .get_converter (opset ),
487+ "Clip" : Clip .get_converter (opset ),
488+ "Equal" : Equal .get_converter (opset ),
489+ "Shape" : Shape .get_converter (opset ),
490+ "Not" : Not .get_converter (opset ),
491+ "Tanh" : Tanh .get_converter (opset ),
492+ "Sqrt" : Sqrt .get_converter (opset ),
493+ "Relu" : Relu .get_converter (opset ),
494+ "Conv" : Conv .get_converter (opset ),
495+ "Pow" : Pow .get_converter (opset ),
496+ "Erf" : Erf .get_converter (opset ),
497+ "CumSum" : CumSum .get_converter (opset ),
344498 }
345499
346500
@@ -630,4 +784,4 @@ def from_onnx(model, shape=None, dtype="float32", opset=None):
         )
 
     # Use the graph proto as a scope so that ops can access other nodes if needed.
-    return g.from_onnx(graph, opset)
+    return g.from_onnx(graph, opset)
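
For reference, a hypothetical end-to-end usage sketch of the importer. The model path, input name, and shape are placeholders, and the import path is an assumption that may differ between TVM versions:

import onnx
# Import path is an assumption; adjust to where this frontend lives in your build.
from tvm.relax.frontend.onnx import from_onnx

model = onnx.load("model.onnx")  # placeholder path
mod = from_onnx(model, shape={"input": (1, 3, 224, 224)}, dtype="float32")
print(mod)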