Commit: Lint fix
Re-trigger CI

Bug fix

Re-trigger CI

Re-trigger CI

Re-trigger CI
jcf94 committed Jul 28, 2021
1 parent 8dac20a commit 0649ec8
Showing 4 changed files with 20 additions and 7 deletions.
python/tvm/relay/op/strategy/cuda.py (8 changes: 4 additions & 4 deletions)
@@ -823,8 +823,8 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         x.dtype == "int8"
         and y.dtype == "int8"
         and out_type.dtype == "int32"
-        and attrs["transpose_a"] == False
-        and attrs["transpose_b"] == True
+        and not attrs["transpose_a"]
+        and attrs["transpose_b"]
     ):
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True),
@@ -849,8 +849,8 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     if (
         target.kind.name == "cuda"
         and nvcc.have_tensorcore(target=target)
-        and attrs["transpose_a"] == False
-        and attrs["transpose_b"] == True
+        and not attrs["transpose_a"]
+        and attrs["transpose_b"]
     ):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
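
The two hunks above are the lint fix proper: pylint's singleton-comparison check flags explicit comparisons against the True/False singletons, and a truthiness test is the idiomatic replacement. A minimal, self-contained sketch of the pattern (the plain dict stands in for the Relay attrs object and is only for illustration):

# Illustration only: a plain dict standing in for the Relay attrs object.
attrs = {"transpose_a": False, "transpose_b": True}

# Flagged by pylint's singleton-comparison check:
#     attrs["transpose_a"] == False and attrs["transpose_b"] == True

# Idiomatic truthiness test, equivalent here because the values are plain bools:
if not attrs["transpose_a"] and attrs["transpose_b"]:
    print("selecting the NT (A untransposed, B transposed) batch_matmul path")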
python/tvm/topi/cuda/tensorcore_alter_op.py (5 changes: 3 additions & 2 deletions)
@@ -62,8 +62,9 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
         or isinstance(K, tir.expr.Any)
         or isinstance(N, tir.expr.Any)
     ):
-        # Dynamic shape cannot support alter op layout
-        return
+        # Dynamic shape do not support alter op layout now
+        return None
+
     M = M.value
     K = K.value
     N = N.value
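
Beyond the comment rewording, the bare return becomes an explicit return None, which is presumably what pylint's inconsistent-return-statements check asks for, since the function returns a rewritten expression on its other paths. In TVM's legalize convention a None result means "leave the original op unchanged". A tiny self-contained illustration of the return-style point (the string result is a stand-in, not the real rewrite):

def legalize_sketch(is_dynamic):
    """Return a rewritten op, or None to keep the original (stand-in example)."""
    if is_dynamic:
        return None  # explicit None instead of a bare return
    return "padded_batch_matmul"  # placeholder for the rewritten expression

assert legalize_sketch(True) is None        # dynamic shape: keep the op as is
assert legalize_sketch(False) == "padded_batch_matmul"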
python/tvm/topi/nn/batch_matmul.py (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ def batch_matmul(
     if oshape is None:
         assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
         batch = (
-            tvm.tir.Any()
+            tvm.tir.expr.SizeVar("batch", "int32")
             if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, tvm.tir.expr.Var)
             else te.max(XB, YB)
         )
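
This hunk is the likely "Bug fix" from the commit message: when either input batch dimension is symbolic, the output batch dimension is now a named, non-negative SizeVar rather than tvm.tir.Any(), the intent apparently being to give the compute a proper symbolic extent to work with. A rough, runnable sketch of declaring a batch_matmul-style compute with such a symbolic batch dimension (shapes and names are illustrative, not the TVM op itself):

import tvm
from tvm import te

batch = tvm.tir.expr.SizeVar("batch", "int32")  # non-negative symbolic extent
M, N, K = 32, 32, 64
x = te.placeholder((batch, M, K), name="x")
y = te.placeholder((batch, N, K), name="y")  # B held transposed, as in the NT layout
k = te.reduce_axis((0, K), name="k")
out = te.compute(
    (batch, M, N),
    lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k),
    name="batch_matmul_sketch",
)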
python/tvm/topi/x86/batch_matmul.py (12 changes: 12 additions & 0 deletions)
@@ -62,6 +62,16 @@ def batch_matmul(
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    if cfg.is_fallback:
+        if transpose_a:
+            _, K, M = get_const_tuple(tensor_a.shape)
+        else:
+            _, M, K = get_const_tuple(tensor_a.shape)
+        if transpose_b:
+            _, N, _ = get_const_tuple(tensor_b.shape)
+        else:
+            _, _, N = get_const_tuple(tensor_b.shape)
+        _default_batch_matmul_config(cfg, M, N, K)
     return nn.batch_matmul(
         tensor_a,
         tensor_b,
@@ -199,6 +209,7 @@ def batch_matmul_cblas(
     cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
 ):
     """Compute batch_matmul using cblas"""
+    del out_dtype  # Unused argument
     return batch_matmul_blas_common(
         cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas
     )
@@ -215,6 +226,7 @@ def batch_matmul_mkl(
     cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
 ):
     """Compute batch_matmul using mkl"""
+    del out_dtype  # Unused argument
     return batch_matmul_blas_common(
         cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl
     )
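
The ten lines added to the x86 compute handle the untuned case: on an AutoTVM fallback config, the M, N, K extents are read from the (possibly transposed) inputs and passed to the file's existing _default_batch_matmul_config helper so the schedule has workable defaults. A hedged sketch of that fallback pattern, assuming AutoTVM's SplitEntity API; the knob names tile_y/tile_x/tile_k and the tiling heuristic are assumptions, not the helper's actual contents:

from tvm.autotvm.task.space import SplitEntity

def _pow2_factor(extent, cap):
    """Largest power-of-two divisor of extent, capped at cap (illustrative helper)."""
    t = 1
    while t * 2 <= cap and extent % (t * 2) == 0:
        t *= 2
    return t

def _fallback_config_sketch(cfg, M, N, K):
    # Only fill in defaults when no tuned entry was found in the tuning log.
    if not cfg.is_fallback:
        return
    ty, tx, tk = _pow2_factor(M, 8), _pow2_factor(N, 8), _pow2_factor(K, 16)
    cfg["tile_y"] = SplitEntity([M // ty, ty])  # assumed knob name
    cfg["tile_x"] = SplitEntity([N // tx, tx])  # assumed knob name
    cfg["tile_k"] = SplitEntity([K // tk, tk])  # assumed knob name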
