Skip to content

Commit 8b2ee8b

Browse files
committed
Bug fix for TensorFlow test
1 parent e3d9708 commit 8b2ee8b

File tree

3 files changed

+37
-12
lines changed

3 files changed

+37
-12
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1256,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_conf
12561256
Dict of converted parameters stored in tvm.nd.NDArray format
12571257
"""
12581258
global TF_DEFAULT_CONFIGS
1259-
TF_DEFAULT_CONFIGS.update(convert_config)
1259+
if convert_config is not None:
1260+
TF_DEFAULT_CONFIGS.update(convert_config)
12601261

12611262
g = GraphProto()
12621263
mod, params = g.from_tensorflow(graph, layout, shape, outputs)

python/tvm/topi/cuda/batch_matmul.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _callback(op):
140140

141141

142142
@autotvm.register_topi_compute("batch_matmul_cublas.cuda")
143-
def batch_matmul_cublas(cfg, x, y, out_shape=None):
143+
def batch_matmul_cublas(cfg, x, y, out_shape=None, transpose_a=False, transpose_b=True):
144144
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
145145
data in batch.
146146
@@ -160,11 +160,17 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
160160
output : tvm.te.Tensor
161161
3-D with shape [batch, M, N]
162162
"""
163-
b, m, k = get_const_tuple(x.shape)
164-
b, n, k = get_const_tuple(y.shape)
163+
if transpose_a:
164+
b, k, m = get_const_tuple(x.shape)
165+
else:
166+
b, m, k = get_const_tuple(x.shape)
167+
if transpose_b:
168+
b, n, k = get_const_tuple(y.shape)
169+
else:
170+
b, k, n = get_const_tuple(y.shape)
165171
if all([isinstance(s, int) for s in [b, m, n, k]]):
166172
cfg.add_flop(b * m * k * n * 2)
167-
return cublas.batch_matmul(x, y, False, True)
173+
return cublas.batch_matmul(x, y, transa=transpose_a, transb=transpose_b)
168174

169175

170176
@autotvm.register_topi_schedule("batch_matmul_cublas.cuda")

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,10 +1812,10 @@ def _test_matmul(i, j, k, dtype, outer=None):
18121812
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
18131813
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
18141814
compare_tf_with_tvm(
1815-
[A_np, B_np], [A.name, B.name], result.name, {"use_dense": True}
1815+
[A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": True}
18161816
)
18171817
compare_tf_with_tvm(
1818-
[A_np, B_np], [A.name, B.name], result.name, {"use_dense": False}
1818+
[A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": False}
18191819
)
18201820

18211821

@@ -1835,10 +1835,16 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
18351835
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
18361836
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
18371837
compare_tf_with_tvm(
1838-
[A_np, B_np], [A.name, B.name], result.name, {"use_nt_batch_matmul": True}
1838+
[A_np, B_np],
1839+
[A.name, B.name],
1840+
result.name,
1841+
convert_config={"use_nt_batch_matmul": True},
18391842
)
18401843
compare_tf_with_tvm(
1841-
[A_np, B_np], [A.name, B.name], result.name, {"use_nt_batch_matmul": False}
1844+
[A_np, B_np],
1845+
[A.name, B.name],
1846+
result.name,
1847+
convert_config={"use_nt_batch_matmul": False},
18421848
)
18431849

18441850

@@ -1852,10 +1858,23 @@ def _test_batch_matmul_dynamic(
18521858

18531859
A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype)
18541860
B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype)
1855-
# for now, in TOPI, only cublas's implementation support dynamic shape
1861+
# for now, in TOPI, only llvm & cublas's implementation support dynamic shape
18561862
# TODO add more backends support in TOPI
18571863
compare_tf_with_tvm(
1858-
[A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"]
1864+
[A_np, B_np],
1865+
[A.name, B.name],
1866+
result.name,
1867+
mode="vm",
1868+
targets=["llvm", "cuda -libs=cublas"],
1869+
convert_config={"use_nt_batch_matmul": True},
1870+
)
1871+
compare_tf_with_tvm(
1872+
[A_np, B_np],
1873+
[A.name, B.name],
1874+
result.name,
1875+
mode="vm",
1876+
targets=["llvm", "cuda -libs=cublas"],
1877+
convert_config={"use_nt_batch_matmul": False},
18591878
)
18601879

18611880

@@ -1874,7 +1893,6 @@ def test_forward_batch_matmul():
18741893
_test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False)
18751894

18761895

1877-
@tvm.testing.requires_cuda
18781896
def test_forward_batch_matmul_dynamic():
18791897
_test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32")
18801898
_test_batch_matmul_dynamic(

0 commit comments

Comments (0)