Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair #7845

Merged
merged 13 commits into from
Apr 15, 2021
20 changes: 16 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ def matmul(self, inputs, input_types):
b_shape = self.infer_shape_with_prelude(inputs_1)

# When performing a batch matmul, we need to properly handle N-dim shapes.
if len(a_shape) > 2 or len(b_shape) > 2:
if len(a_shape) > 2 and len(b_shape) > 2:
# Convert a into a 3-dimensional tensor.
need_reshape_output = False
if len(a_shape) != 3:
Expand All @@ -1606,18 +1606,30 @@ def matmul(self, inputs, input_types):
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
return output
elif len(a_shape) > 2:
inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])

# Otherwise a simple dense op will get the job done.
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
if len(b_shape) == 1:
if len(b_shape) > 2:
trans_axes = list(range(len(b_shape)))
trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
elif len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)
else:
elif len(b_shape) == 2:
input_1 = _op.transpose(inputs_1, axes=(1, 0))
jcf94 marked this conversation as resolved.
Show resolved Hide resolved

out = _op.nn.dense(inputs_0, input_1)

if len(b_shape) == 1:
out = _op.squeeze(out, axis=[-1])

# Reshape the output into an N-dimensional tensor when a or b has more than 2 dims
if len(a_shape) > 2:
out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
elif len(b_shape) > 2:
out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
out = _op.reshape(_op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]])

jcf94 marked this conversation as resolved.
Show resolved Hide resolved
return out

def expand(self, inputs, input_types):
Expand Down
28 changes: 24 additions & 4 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
return est


def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good addition!

"""Assert that the output of a compiled model matches with that of its
baseline."""
if isinstance(model_name, str):
Expand Down Expand Up @@ -219,6 +219,21 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)

if len(expected_ops) != 0:
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
found_op = dict.fromkeys(expected_ops, False)
def visit(op):
if isinstance(op, tvm.ir.op.Op):
if op.name in expected_ops:
found_op[op.name] = True

tvm.relay.analysis.post_order_visit(mod['main'].body, visit)

for op_name, is_found in enumerate(found_op):
if not is_found:
msg = "TVM Relay do not contain expected op [{}]"
raise AssertionError(msg.format(op_name))
jcf94 marked this conversation as resolved.
Show resolved Hide resolved

del model_name
del baseline_model
torch.cuda.empty_cache()
Expand Down Expand Up @@ -3304,17 +3319,22 @@ def forward(self, *args):
# matrix x matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(4, 10)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])

# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.batch_matmul'])

# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])

# broadcasted matrix x batched matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(3, 4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=['nn.dense'])

# batched matrix x batched matrix
tensor1 = torch.randn(1, 12, 14, 64)
Expand Down