@@ -945,7 +945,6 @@ def _impl_v9(cls, inputs, attr, params):
945
945
return out
946
946
947
947
948
-
949
948
class Shape (OnnxOpConverter ):
950
949
"""Operator converter for Shape."""
951
950
@@ -1047,24 +1046,35 @@ def _impl_v1(cls, inputs, attr, params):
1047
1046
1048
1047
@classmethod
1049
1048
def _impl_v10 (cls , inputs , attr , params ):
1050
- attrs = {"starts" : inputs [1 ], "ends" : inputs [2 ]}
1051
- if len (inputs ) >= 4 :
1052
- attrs ["axes" ] = inputs [3 ]
1053
- attrs = {k : (v , get_name (v )) for (k , v ) in attrs .items ()}
1054
- attrs = {
1055
- k : params [v [1 ]].asnumpy ()
1056
- if v [1 ] in params
1057
- else infer_value_simulated (v [0 ], params ).asnumpy ()
1058
- for (k , v ) in attrs .items ()
1059
- }
1049
+ starts = inputs [1 ]
1050
+ ends = inputs [2 ]
1051
+ axes = inputs [3 ]
1052
+ steps = inputs [4 ]
1053
+
1054
+ data_rank = len (infer_shape (inputs [0 ]))
1060
1055
1061
1056
# Update the starts and ends according to axes if required.
1062
- if "axes" in attrs :
1063
- if max (attrs ["axes" ] + 1 ) != len (attrs ["axes" ]):
1064
- new_starts , new_ends , _ = cls ._common (attrs ["starts" ], attrs ["ends" ], attrs ["axes" ])
1065
- attrs ["starts" ] = new_starts
1066
- attrs ["ends" ] = new_ends
1067
- return _op .strided_slice (inputs [0 ], begin = list (attrs ["starts" ]), end = list (attrs ["ends" ]))
1057
+ if axes is not None :
1058
+ data_shape = _op .shape_of (inputs [0 ], dtype = infer_type (ends ).checked_type .dtype )
1059
+ starts = _op .scatter (
1060
+ _op .const ([0 ] * data_rank , dtype = infer_type (starts ).checked_type .dtype ),
1061
+ axes ,
1062
+ starts ,
1063
+ axis = 0 ,
1064
+ )
1065
+ ends = _op .scatter (data_shape , axes , ends , axis = 0 )
1066
+ if steps is not None :
1067
+ steps = _op .scatter (
1068
+ _op .const ([1 ] * data_rank , dtype = infer_type (steps ).checked_type .dtype ),
1069
+ axes ,
1070
+ steps ,
1071
+ axis = 0 ,
1072
+ )
1073
+
1074
+ if steps is None :
1075
+ steps = _op .const ([1 ] * data_rank , dtype = infer_type (starts ).checked_type .dtype )
1076
+
1077
+ return _op .strided_slice (inputs [0 ], starts , ends , steps )
1068
1078
1069
1079
1070
1080
class Gather (OnnxOpConverter ):
@@ -1406,7 +1416,6 @@ def _impl_v6(cls, inputs, attr, params):
1406
1416
return _op .tile (inputs [0 ], inputs [1 ])
1407
1417
1408
1418
1409
-
1410
1419
class Erf (OnnxOpConverter ):
1411
1420
"""Operator converter for Erf"""
1412
1421
0 commit comments