Skip to content

Commit

Permalink
check align on N dim
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed Dec 11, 2021
1 parent 308c4da commit 6b780db
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self, sm, cutlass_path, binary_path):
self.sm = sm
self.cache = {}

def check_align(self, op_name, M, K):
def check_align(self, op_name, M, N, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
Expand All @@ -170,7 +170,7 @@ def check_align(self, op_name, M, K):
# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
# See https://github.com/NVIDIA/cutlass/issues/362.
# When the above issue is resolved, we can remove the alignment check on M below.
return M % align == 0 and K % align == 0
return all([dim % align == 0 for dim in [M, N, K]])

def get_default(self, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
Expand All @@ -197,7 +197,7 @@ def profile(
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
)
ops = list(filter(lambda op: self.check_align(op["name"], M, K), ops))
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

for op in ops:
op["runtime"] = -1
Expand Down
2 changes: 2 additions & 0 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ def verify_batch_matmul(
def test_dense():
verify_dense(get_dense(M, N, K), M, N, K)
verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
# Test align1 case
verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K)


def test_dense_bias():
Expand Down

0 comments on commit 6b780db

Please sign in to comment.