Merged
6 changes: 3 additions & 3 deletions tf2onnx/onnx_opset/math.py
@@ -545,9 +545,9 @@ def version_11(cls, ctx, node, **kwargs):
                              shapes=shapes, dtypes=dtypes, domain=constants.ONNX_DOMAIN, attr={'direction': direction})
 
         if node.maybe_cast_input([supported, supported], type_map):
-            cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                            name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", dtypes[0])
+            cast_back_node = ctx.insert_new_node_on_output(
+                "Cast", node.output[0], name=utils.make_name(node.name) + "_castback",
+                to=dtypes[0])
             ctx.set_dtype(cast_back_node.output[0], dtypes[0])
             ctx.copy_shape(node.name, cast_back_node.output[0])

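Every hunk in this PR applies the same mechanical change: instead of inserting a Cast node and then calling set_attr("to", ...) on it, the target dtype is handed to insert_new_node_on_input / insert_new_node_on_output as a to= keyword, which the graph helpers forward as the node's ONNX attribute. A minimal, self-contained sketch of that calling convention, assuming a hypothetical helper that forwards extra kwargs as attributes (this is not the tf2onnx implementation):

```python
# Sketch only: a stand-in helper that mimics the new calling convention by
# forwarding extra keyword arguments as ONNX attributes.
from onnx import TensorProto, helper


def insert_new_node_on_output(op_type, output_name, name, **attrs):
    """Hypothetical helper: build a node consuming `output_name` and carry
    any extra kwargs (e.g. to=...) as ONNX attributes."""
    return helper.make_node(op_type, inputs=[output_name],
                            outputs=[name + ":0"], name=name, **attrs)


# Old style: create the Cast first, attach the attribute afterwards.
cast_old = helper.make_node("Cast", ["x:0"], ["x_castback:0"], name="x_castback")
cast_old.attribute.extend([helper.make_attribute("to", TensorProto.FLOAT)])

# New style: the attribute travels with the constructor call.
cast_new = insert_new_node_on_output("Cast", "x:0", name="x_castback",
                                     to=TensorProto.FLOAT)

# Both spellings end up with the same "to" attribute on the node.
assert cast_old.attribute[0].i == cast_new.attribute[0].i == TensorProto.FLOAT
```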
14 changes: 6 additions & 8 deletions tf2onnx/onnx_opset/nn.py
@@ -637,14 +637,13 @@ def version_1(cls, ctx, node, **kwargs):
         origin_dtype = ctx.get_dtype(node.output[0])
         if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
                                 onnx_pb.TensorProto.DOUBLE]:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
             ctx.copy_shape(node.name, cast_node.output[0])
 
             cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                            name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", origin_dtype)
+                                                            name=utils.make_name(node.name) + "_castback",
+                                                            to=origin_dtype)
             ctx.set_dtype(cast_back_node.output[0], origin_dtype)
             ctx.copy_shape(node.name, cast_back_node.output[0])

@@ -667,14 +666,13 @@ def version_11(cls, ctx, node, **kwargs):
         origin_dtype = ctx.get_dtype(node.output[0])
         if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
                                 TensorProto.INT32, TensorProto.INT64]:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-            cast_node.set_attr("to", TensorProto.FLOAT)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
             ctx.copy_shape(node.name, cast_node.output[0])
 
             cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                            name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", origin_dtype)
+                                                            name=utils.make_name(node.name) + "_castback",
+                                                            to=origin_dtype)
             ctx.set_dtype(cast_back_node.output[0], origin_dtype)
             ctx.copy_shape(node.name, cast_back_node.output[0])

37 changes: 17 additions & 20 deletions tf2onnx/onnx_opset/tensor.py
@@ -31,8 +31,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
     """cast int32 shape into int64 shape."""
     name = node.input[input_number]
 
-    cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
-    cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+    cast_node = ctx.insert_new_node_on_input(node, "Cast", name, to=onnx_pb.TensorProto.INT64)
     ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
     ctx.copy_shape(name, cast_node.output[0])

@@ -46,14 +45,14 @@ def _wrap_concat_with_cast(ctx, node):
     output_name = node.output[0]
     # cast each inputs to float
     for i, inp in enumerate(node.inputs):
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
+                                                   to=onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
     next_nodes = ctx.find_output_consumers(node.output[0])
     # cast output back to dtype unless the next op is a cast
     if next_nodes[0].type != "Cast":
-        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name())
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
+                                                     to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(output_name, output_cast.output[0])
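The helper touched here wraps a non-float Concat in casts because old opsets only define Concat for float tensors. A hedged sketch of the subgraph it effectively builds, written against onnx.helper rather than tf2onnx's graph API, with illustrative tensor names:

```python
# Sketch of the wrapped graph shape: Cast inputs to float, Concat, Cast back.
from onnx import TensorProto, helper


def int32_vec(name):
    # 1-D int32 tensor of unknown length, used for the example inputs/output
    return helper.make_tensor_value_info(name, TensorProto.INT32, [None])


graph = helper.make_graph(
    [
        helper.make_node("Cast", ["a"], ["a_f"], to=TensorProto.FLOAT),
        helper.make_node("Cast", ["b"], ["b_f"], to=TensorProto.FLOAT),
        helper.make_node("Concat", ["a_f", "b_f"], ["c_f"], axis=0),
        # cast back to the original dtype; skipped when the consumer is already a Cast
        helper.make_node("Cast", ["c_f"], ["c"], to=TensorProto.INT32),
    ],
    "wrapped_concat",
    [int32_vec("a"), int32_vec("b")],
    [int32_vec("c")],
)
onnx_model = helper.make_model(graph)
```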

@@ -157,15 +156,14 @@ def version_5(cls, ctx, node, **kwargs):
             return
 
         # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.output[0], input_cast.output[0])
 
         # if the next node is already a cast we don't need to insert another one
         next_nodes = ctx.find_output_consumers(node.output[0])
         if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
-            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name())
-            output_cast.set_attr("to", dtype)
+            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
+                                                         to=dtype)
             ctx.set_dtype(output_cast.output[0], dtype)
             ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -739,16 +737,17 @@ def version_1(cls, ctx, node, **kwargs):
         if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
             # override the previous cast
             cast_node = node.inputs[0]
+            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         else:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
+                                                     to=onnx_pb.TensorProto.FLOAT)
             nodes.insert(0, cast_node)
-        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.input[0], cast_node.output[0])
         # undo the cast afer slice
         name = utils.make_name(node.name)
-        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-        cast_node.set_attr("to", input_dtype)
+        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
+                                                   to=input_dtype)
         ctx.set_dtype(cast_node.output[0], input_dtype)
         ctx.copy_shape(node.output[0], cast_node.output[0])
         nodes.append(cast_node)
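Note the asymmetry in this hunk: the branch that reuses an upstream Cast still calls set_attr, because that node already exists; only newly created casts take the to= keyword. The len(...) == 1 guard is what makes the reuse safe. A toy model with plain dicts (not tf2onnx objects) showing why a shared Cast must not be retargeted in place:

```python
# A Cast whose output feeds both the slice and another consumer.
cast = {"type": "Cast", "to": "int64", "consumers": ["Slice_1", "Add_1"]}


def retarget_or_copy(cast_node, new_to):
    """Mutate the Cast in place only when the slice is its sole consumer."""
    if len(cast_node["consumers"]) == 1:
        cast_node["to"] = new_to              # safe: nobody else reads its output
        return cast_node
    # otherwise leave it alone and make a fresh Cast for the slice only
    return {**cast_node, "to": new_to, "consumers": ["Slice_1"]}


new_cast = retarget_or_copy(cast, "float32")
assert cast["to"] == "int64"        # Add_1 still sees the original dtype
assert new_cast["to"] == "float32"  # the slice gets its float input
```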
@@ -1180,8 +1179,7 @@ def version_1(cls, ctx, node, **kwargs):
         if dtype == onnx_pb.TensorProto.INT64:
             return
         op_name = utils.make_name(node.name)
-        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -1555,8 +1553,7 @@ def version_8(cls, ctx, node, **kwargs):
 
         seq_len_dtype = ctx.get_dtype(node.input[1])
         if seq_len_dtype != onnx_pb.TensorProto.INT64:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
-            cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
             ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
             ctx.copy_shape(node.input[1], cast_node.output[0])

@@ -1762,8 +1759,8 @@ def version_11(cls, ctx, node, **kwargs):
         # cast to int64 if needed
         if dtypes[1] != onnx_pb.TensorProto.UINT64:
             cast_node = ctx.insert_new_node_on_output("Cast", node.output[1],
-                                                       name=utils.make_name(node.name) + "_cast")
-            cast_node.set_attr("to", dtypes[1])
+                                                       name=utils.make_name(node.name) + "_cast",
+                                                       to=dtypes[1])
             ctx.set_dtype(cast_node.output[0], dtypes[1])
             ctx.copy_shape(node.output[1], cast_node.output[0])
         # FIXME: the indices in onnx are not the same as in tensorflow.
8 changes: 4 additions & 4 deletions tf2onnx/tfonnx.py
@@ -160,8 +160,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
                     input_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
                     g.set_dtype(input_name, onnx_pb.TensorProto.FLOAT)
                 else:
-                    cast_node = g.insert_new_node_on_input(op, "Cast", input_name)
-                    cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+                    cast_node = g.insert_new_node_on_input(op, "Cast", input_name,
+                                                           to=onnx_pb.TensorProto.FLOAT)
                     g.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
                     g.copy_shape(input_name, cast_node.output[0])
                     cast_inserted.append(cast_node)
@@ -171,8 +171,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
                 name = utils.make_name(op.name)
                 logger.debug("insert cast back for node %s on output %s [dtype=%s]", op.name, output_name,
                              output_dtype)
-                output_cast = g.insert_new_node_on_output("Cast", output_name, name=name)
-                output_cast.set_attr("to", output_dtype)
+                output_cast = g.insert_new_node_on_output("Cast", output_name, name=name,
+                                                          to=output_dtype)
                 g.set_dtype(output_cast.output[0], output_dtype)
                 g.copy_shape(output_name, output_cast.output[0])
                 cast_inserted.append(output_cast)
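Judging by its name and the FLOAT casts in these hunks, rewrite_incomplete_type_support routes tensors of unsupported dtypes through FLOAT and back. For integer tensors that implicitly assumes the float32 round trip preserves the values involved; a quick numpy check of that assumption (exact for magnitudes up to 2**24):

```python
# Verify that int32 values survive a cast to float32 and back, which is the
# property the cast/cast-back wrapping relies on for integer inputs.
import numpy as np

x = np.array([0, 1, -7, 2**24], dtype=np.int32)
roundtrip = x.astype(np.float32).astype(np.int32)
assert np.array_equal(x, roundtrip)
```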