Skip to content

Commit 64423cd

Browse files
authored
Merge pull request onnx#526 from zhijxu-MS/bert_bug
fix bug
2 parents 61ff34f + 95d203a commit 64423cd

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tf2onnx/graph_builder.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
4747
# input sequence should be "data", "starts", "ends", "axes", "steps"
4848
attr = {}
4949
data = self.convert_to_input(kwargs.pop("data"))
50-
starts = self.convert_to_input(kwargs.pop("starts"))
51-
ends = self.convert_to_input(kwargs.pop("ends"))
52-
axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True)
53-
steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True)
50+
starts = self.convert_to_input(kwargs.pop("starts"), dtype=np.int64)
51+
ends = self.convert_to_input(kwargs.pop("ends"), dtype=np.int64)
52+
axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True, dtype=np.int64)
53+
steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True, dtype=np.int64)
5454
inputs = [data, starts, ends, axes, steps]
5555

5656
# pro-process inputs and attr
@@ -78,7 +78,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
7878
return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
7979
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
8080

81-
def convert_to_input(self, tensor, is_optional=False):
81+
def convert_to_input(self, tensor, is_optional=False, dtype=None):
8282
"""in ONNX, input shold come from node, so it must be a string"""
8383
if is_optional and tensor is None:
8484
return None
@@ -87,7 +87,7 @@ def convert_to_input(self, tensor, is_optional=False):
8787

8888
res = tensor
8989
if isinstance(tensor, list):
90-
res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor)).output[0]
90+
res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor, dtype)).output[0]
9191

9292
utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed")
9393

0 commit comments

Comments
 (0)