Skip to content

Commit ab75b58

Browse files
authored
[Bug][Relay] fix relay frontend pytorch op addmm bug (#15294)
* Update pytorch.py: fix the calculation-formula error in the Relay PyTorch frontend's `addmm` op. Buggy formula: `out = input + alpha * beta * (mat1 @ mat2)`; corrected formula: `out = beta * input + alpha * (mat1 @ mat2)`. * Add a Relay PyTorch frontend `addmm` op test. * Fix the Relay PyTorch frontend `addmm` op. * Add a Relay PyTorch frontend `addmm` op test.
1 parent ae3de3d commit ab75b58

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,24 +1549,22 @@ def flatten(self, inputs, input_types):
15491549
def addmm(self, inputs, input_types):
    """Convert ``torch.addmm`` to Relay.

    Computes ``out = beta * input + alpha * (mat1 @ mat2)``, matching
    PyTorch's addmm semantics.

    inputs: [input_mat, mat1, mat2, beta, alpha]; data_type is taken
    from input_types[1] (mat1's dtype) for typed scalar constants.
    NOTE(review): when alpha or beta arrives as an _expr.Expr it is
    left unapplied — same as the original code; confirm upstream intent.
    """
    bias_mat = inputs[0]
    mat1 = inputs[1]
    mat2 = inputs[2]
    beta = inputs[3]
    alpha = inputs[4]
    data_type = input_types[1]

    # nn.dense computes A @ B^T, so transpose mat2 first to get mat1 @ mat2.
    mat2_t = _op.transform.transpose(mat2, axes=[1, 0])
    out_units = self.infer_shape(mat2_t)[0]
    matmul_out = _op.nn.dense(mat1, mat2_t, units=out_units)

    # Scale the matmul result by alpha (skip the no-op alpha == 1 case).
    if not isinstance(alpha, _expr.Expr) and alpha != 1:
        matmul_out *= _create_typed_const(alpha, data_type)

    # Scale the additive input by beta (skip the no-op beta == 1 case).
    if not isinstance(beta, _expr.Expr) and beta != 1:
        bias_mat *= _create_typed_const(beta, data_type)

    return matmul_out + bias_mat
15721570

tests/python/frontend/pytorch/test_forward.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5260,6 +5260,18 @@ def test_weight_norm():
52605260
verify_model(linear_wn.eval().float(), input_data_linear)
52615261

52625262

5263+
@tvm.testing.uses_gpu
def test_addmm():
    """Check torch.addmm conversion: out = beta * inp + alpha * (b1 @ b2)."""

    def make_fn(alpha, beta):
        def fn(inp, batch1, batch2):
            return torch.addmm(inp, batch1, batch2, beta=beta, alpha=alpha)

        return fn

    # (3, 4) @ (4, 5) -> (3, 5), added to a (3, 5) input matrix.
    inp = torch.randn(3, 5)
    lhs = torch.randn(3, 4)
    rhs = torch.randn(4, 5)

    verify_model(make_fn(0.4, 0.8), [inp, lhs, rhs])
52635275
@tvm.testing.uses_gpu
52645276
def test_baddbmm():
52655277
def test_fn(alpha, beta):

0 commit comments

Comments
 (0)