@@ -47,10 +47,10 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
47
47
# input sequence should be "data", "starts", "ends", "axes", "steps"
48
48
attr = {}
49
49
data = self .convert_to_input (kwargs .pop ("data" ))
50
- starts = self .convert_to_input (kwargs .pop ("starts" ))
51
- ends = self .convert_to_input (kwargs .pop ("ends" ))
52
- axes = self .convert_to_input (kwargs .pop ("axes" , None ), is_optional = True )
53
- steps = self .convert_to_input (kwargs .pop ("steps" , None ), is_optional = True )
50
+ starts = self .convert_to_input (kwargs .pop ("starts" ), dtype = np . int64 )
51
+ ends = self .convert_to_input (kwargs .pop ("ends" ), dtype = np . int64 )
52
+ axes = self .convert_to_input (kwargs .pop ("axes" , None ), is_optional = True , dtype = np . int64 )
53
+ steps = self .convert_to_input (kwargs .pop ("steps" , None ), is_optional = True , dtype = np . int64 )
54
54
inputs = [data , starts , ends , axes , steps ]
55
55
56
56
# pro-process inputs and attr
@@ -78,7 +78,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
78
78
return self .graph .make_node (op_type = "Slice" , inputs = inputs , attr = attr , name = name ,
79
79
outputs = outputs , shapes = shapes , dtypes = dtypes ).output [0 ]
80
80
81
- def convert_to_input (self , tensor , is_optional = False ):
81
+ def convert_to_input (self , tensor , is_optional = False , dtype = None ):
82
82
"""in ONNX, input shold come from node, so it must be a string"""
83
83
if is_optional and tensor is None :
84
84
return None
@@ -87,7 +87,7 @@ def convert_to_input(self, tensor, is_optional=False):
87
87
88
88
res = tensor
89
89
if isinstance (tensor , list ):
90
- res = self .graph .make_const (utils .make_name ("const_slice" ), np .array (tensor )).output [0 ]
90
+ res = self .graph .make_const (utils .make_name ("const_slice" ), np .array (tensor , dtype )).output [0 ]
91
91
92
92
utils .make_sure (isinstance (res , str ), "input is a dynamic input, so a str is needed" )
93
93
0 commit comments