Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair #7845

Merged
merged 13 commits into from
Apr 15, 2021
24 changes: 19 additions & 5 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ def matmul(self, inputs, input_types):
b_shape = self.infer_shape_with_prelude(inputs_1)

# When performing a batch matmul, we need to properly handle N-dim shapes.
if len(a_shape) > 2 or len(b_shape) > 2:
if len(a_shape) > 2 and len(b_shape) > 2:
# Convert a into a 3 dimensional tensors.
need_reshape_output = False
if len(a_shape) != 3:
Expand All @@ -1606,18 +1606,32 @@ def matmul(self, inputs, input_types):
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
return output
elif len(a_shape) > 2:
inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])

# Otherwise a simple dense op will get the job done.
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
if len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)
else:
if len(b_shape) > 2:
trans_axes = list(range(len(b_shape)))
trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
elif len(b_shape) == 2:
input_1 = _op.transpose(inputs_1, axes=(1, 0))
elif len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)

out = _op.nn.dense(inputs_0, input_1)

if len(b_shape) == 1:
out = _op.squeeze(out, axis=[-1])

# Reshape output into a N dimensional tensor when a or b dim > 2
if len(a_shape) > 2:
out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
elif len(b_shape) > 2:
out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
out = _op.reshape(
_op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
)

return out

def expand(self, inputs, input_types):
Expand Down
31 changes: 27 additions & 4 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
return est


def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
def verify_model(
model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
):
"""Assert that the output of a compiled model matches with that of its
baseline."""
if isinstance(model_name, str):
Expand Down Expand Up @@ -219,6 +221,20 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)

if expected_ops:

def visit(op):
if isinstance(op, tvm.ir.op.Op):
if op.name in expected_ops:
expected_ops.remove(op.name)

tvm.relay.analysis.post_order_visit(mod["main"].body, visit)

if expected_ops:
msg = "TVM Relay do not contain expected ops {}"
raise AssertionError(msg.format(expected_ops))

del model_name
del baseline_model
torch.cuda.empty_cache()
Expand Down Expand Up @@ -3304,17 +3320,24 @@ def forward(self, *args):
# matrix x matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(4, 10)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])

# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(
MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
)

# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])

# broadcasted matrix x batched matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(3, 4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])

# batched matrix x batched matrix
tensor1 = torch.randn(1, 12, 14, 64)
Expand Down