48 changes: 42 additions & 6 deletions coremltools/converters/mil/frontend/torch/test/test_passes.py
@@ -2,6 +2,7 @@
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from collections import OrderedDict

import numpy as np
@@ -18,6 +19,7 @@
flatten_graph_output_values,
transform_inplace_ops,
)
import coremltools as ct


def _build_flattening_test_graph():
@@ -78,15 +80,18 @@ def _build_flattening_test_graph():


class TestTorchPasses:
"""Class containing tests for InternalTorchIR optimization passes.
"""
Class containing tests for InternalTorchIR optimization passes.
"""

@pytest.fixture
def set_random_seeds(self):
torch.manual_seed(1)
np.random.seed(1)

def test_flatten_input_values(self):

@staticmethod
def test_flatten_input_values():
graph = _build_flattening_test_graph()

flatten_graph_input_values(graph)
@@ -107,7 +112,9 @@ def test_flatten_input_values(self):
# next op.
np.testing.assert_equal(graph.nodes[1].outputs[0], graph.nodes[2].inputs[0])

def test_flatten_output_values(self):

@staticmethod
def test_flatten_output_values():
graph = _build_flattening_test_graph()

flatten_graph_output_values(graph)
@@ -119,7 +126,9 @@ def test_flatten_output_values(self):
np.testing.assert_equal(graph.outputs[1], graph.nodes[1].outputs[0])
np.testing.assert_equal(graph.outputs[2], graph.nodes[1].outputs[1])

def test_transform_inplace_ops_graph(self):

@staticmethod
def test_transform_inplace_ops_graph():
# The test graph is:
# graph(
# %x : Tensor[1],
@@ -171,7 +180,9 @@ def test_transform_inplace_ops_graph(self):
np.testing.assert_equal(len(graph.outputs), 1)
np.testing.assert_equal(graph.outputs[0], graph.nodes[-1].outputs[0])

def test_transform_inplace_ops_loop(self):

@staticmethod
def test_transform_inplace_ops_loop():
# The test graph is:
# graph(
# %x : Tensor[1],
@@ -264,8 +275,10 @@ def test_transform_inplace_ops_loop(self):
# That graph output should now be the output of the graph.
np.testing.assert_equal(loop_node.outputs[0], graph.outputs[0])


@staticmethod
@pytest.mark.xfail(reason="rdar://64235006")
def test_transform_inplace_ops_if(self):
def test_transform_inplace_ops_if():
# The test graph is:
# graph(
# %x : Tensor[1],
@@ -369,3 +382,26 @@ np.testing.assert_equal(if_node.name, if_node.outputs[0])
np.testing.assert_equal(if_node.name, if_node.outputs[0])
# The graph output should be the if op output.
np.testing.assert_equal(if_node.outputs[0], graph.outputs[0])


@staticmethod
def test_inpace_op_from_cast():
class Net(torch.nn.Module):
def forward(self, x):
y = torch.empty(x.shape).to(torch.int32)
y.fill_(0.2)
return y

shape = (2, 3)
x = torch.rand(*shape)
traced_fn = torch.jit.trace(Net(), x).eval()

ct_model = ct.convert(
traced_fn,
inputs=[ct.TensorType(shape=shape)],
outputs=[ct.TensorType(name="y", dtype=np.int32)],
source="pytorch",
)
y_cm = ct_model.predict({'x': x})['y']

assert((y_cm == np.zeros(shape)).all())
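The new test relies on the traced graph placing a cast between the tensor allocation and the in-place fill: tracing the module above typically produces an empty -> to -> fill_ chain, so fill_ consumes the output of the cast rather than the original buffer. A minimal sketch for inspecting that pattern is shown below; the exact aten:: node kinds in the printed graph can vary across PyTorch versions and are not guaranteed by this PR.

# Sketch only: inspect the TorchScript graph produced by the module used in
# test_inpace_op_from_cast. Node kinds shown in the comment are typical
# PyTorch output, not something this PR asserts.
import torch

class Net(torch.nn.Module):
    def forward(self, x):
        y = torch.empty(x.shape).to(torch.int32)
        y.fill_(0.2)
        return y

traced = torch.jit.trace(Net(), torch.rand(2, 3))
print(traced.graph)
# Expect something like: aten::empty -> aten::to -> aten::fill_, i.e. the
# fill_ target is the cast output, not the tensor returned by empty.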
9 changes: 9 additions & 0 deletions coremltools/converters/mil/frontend/torch/torchir_passes.py
@@ -136,6 +136,15 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse):
node_sequence.append(node)
tensor_to_node_sequence_mapping[node_output] = node_sequence

if node.kind == "to":
node_input = node.inputs[0]
if node_input in tensor_to_node_sequence_mapping:
# The cast produces a new tensor name; re-key the mapping so later
# in-place ops (copy_/fill_) on the cast output still map back to this
# node sequence.
node_output = node.outputs[0]
val = tensor_to_node_sequence_mapping[node_input]
del tensor_to_node_sequence_mapping[node_input]
tensor_to_node_sequence_mapping[node_output] = val

if node.kind in ("copy_", "fill_"):
node_input = node.inputs[0]
if node_input not in tensor_to_node_sequence_mapping:
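To make the intent of the new "to" branch concrete, here is a self-contained toy sketch of the mapping update. Node below is a hypothetical stand-in for the converter's internal node class (not the real InternalTorchIRNode API); only the dictionary re-keying mirrors the code added above.

# Toy illustration of the re-keying performed by the new "to" handling.
from collections import namedtuple

Node = namedtuple("Node", ["kind", "inputs", "outputs"])

nodes = [
    Node("empty", [], ["y0"]),               # y0 = torch.empty(shape)
    Node("to", ["y0"], ["y1"]),               # y1 = y0.to(torch.int32)
    Node("fill_", ["y1", "value"], ["y2"]),   # y1.fill_(0.2)
]

# Suppose the pass is already tracking y0 as a candidate for in-place fusion.
tensor_to_node_sequence_mapping = {"y0": [nodes[0]]}

for node in nodes[1:]:
    if node.kind == "to":
        node_input = node.inputs[0]
        if node_input in tensor_to_node_sequence_mapping:
            # Re-key the tracked sequence under the cast's output name so the
            # later fill_ is still seen as an in-place op on the same buffer.
            val = tensor_to_node_sequence_mapping.pop(node_input)
            tensor_to_node_sequence_mapping[node.outputs[0]] = val
    if node.kind in ("copy_", "fill_"):
        node_input = node.inputs[0]
        # Without the "to" branch this lookup would miss, because the mapping
        # would still be keyed by "y0" while fill_ consumes "y1".
        assert node_input in tensor_to_node_sequence_mapping

Run as-is, the final assert passes; remove the "to" branch and it fails, which mirrors the situation the new test_inpace_op_from_cast test guards against.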