Skip to content
Merged
2 changes: 1 addition & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4706,7 +4706,7 @@ def func(filter_val, out_backprop_val, batch_dim):
def graph_validator(g):
for n in g.get_nodes():
if n.type == 'ConvTranspose':
return "output_shape" in n.attr
return "pads" in n.attr or "output_shape" in n.attr
return False
self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val, _INPUT2: batch_dim_val},
graph_validator=graph_validator)
Expand Down
34 changes: 30 additions & 4 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,12 @@ def version_1(cls, ctx, node, **kwargs):
input_dims = input_shape[1:1+spatial]
else:
input_dims = input_shape[2:2+spatial]
input_dims_known = -1 not in input_dims
output_shape_orig = node.output_shapes

# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
# output_shape is explicitly specified here and then converted to explicit pads.
output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
if output_shape is not None:
#output_shape = ctx.get_shape(node.output[0])
if is_channels_last(node):
new_output_shape = [output_shape[1], output_shape[2]]
if spatial == 3:
Expand All @@ -453,8 +453,34 @@ def version_1(cls, ctx, node, **kwargs):
utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
"output dims cannot be smaller than input dims.")

node.set_attr("output_shape", new_output_shape)
if -1 in input_dims:
node.set_attr("output_shape", new_output_shape)
else:
if "strides" in node.attr:
strides = parse_dims_attr(node, node.get_attr("strides").ints, spatial)
else:
strides = [1] * spatial
if "dilations" in node.attr:
dilations = parse_dims_attr(node, node.get_attr("dilations").ints, spatial)
else:
dilations = [1] * spatial
if "output_padding" in node.attr:
output_padding = parse_dims_attr(node, node.get_attr("output_padding").ints, spatial)
else:
output_padding = [0] * spatial
kernel_shape = parse_dims_attr(node, node.get_attr("kernel_shape").ints, spatial)
total_padding = [-1] * spatial
pads = [1] * (spatial * 2)
for i in range(spatial):
total_padding[i] = (strides[i] * (input_dims[i] - 1) + output_padding[i]
+ ((kernel_shape[i] - 1) * dilations[i] + 1)
- new_output_shape[i])
start_i = i
end_i = i + spatial
pads[start_i] = int(total_padding[i] / 2)
pads[end_i] = total_padding[i] - pads[start_i]
node.set_attr("pads", pads)
node.set_attr("auto_pad", "NOTSET")
else:
utils.make_sure(ctx.opset >= 10, "Opset 10 needed for Conv Backprop Input with non-constant shape")
strides = parse_dims_attr(node, node.get_attr('strides').ints, spatial)
Expand Down