use align1 kernel for unusual channel cases (IC = 3 etc)
masahi committed Dec 11, 2021
1 parent 6cdf205 commit ffce47d
Showing 1 changed file with 12 additions and 9 deletions.
python/tvm/contrib/cutlass/gen_gemm.py (12 additions, 9 deletions)
```diff
@@ -141,12 +141,13 @@ def create_gemm_operator(
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
+    # align1 variants do not seem to be available for sm80
     80: {
-        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
-        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
 }
```

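My reading of the motivation, hedged: an alignN kernel issues N-wide vectorized memory accesses, so each checked GEMM dimension must be divisible by N. A convolution with an unusual input-channel count such as IC = 3 yields K = 3, which no align4 kernel can serve, so the safe defaults move to align1. A minimal sketch of the divisibility constraint (the helper below is hypothetical, not part of TVM):

```python
# Hypothetical helper illustrating the constraint: an "alignN" kernel
# requires every checked GEMM dimension to be divisible by N.
def feasible_alignments(M, K):
    return [a for a in (1, 2, 4, 8) if M % a == 0 and K % a == 0]

# With an input-channel count of 3 (the "IC = 3" case in the commit
# title), K = 3 and only align1 kernels remain eligible.
print(feasible_alignments(M=64, K=3))  # [1]
```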
```diff
@@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         self.cache = {}
 
-    def check_align(self, op_name, M):
+    def check_align(self, op_name, M, K):
         """Filter out kernels that cannot be supported."""
         aligns = re.findall(r"align[1|2|4|8]", op_name)
         assert len(aligns) == 1
+        # The same alignment is used for all axes
         align = int(aligns[0][-1])
-        if M % align != 0:
-            return False
-        return True
+        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+        # See https://github.com/NVIDIA/cutlass/issues/362.
+        # When the above issue is resolved, we can remove the alignment check on M below.
+        return M % align == 0 and K % align == 0
 
     def get_default(self, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.
```
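Restated as a standalone function (a sketch; in TVM this is a method on the profiler class, as the diff above shows):

```python
import re

def check_align(op_name, M, K):
    """Reject kernels whose alignment does not divide both M and K."""
    aligns = re.findall(r"align[1|2|4|8]", op_name)
    assert len(aligns) == 1
    # The same alignment is used for all axes.
    align = int(aligns[0][-1])
    return M % align == 0 and K % align == 0

# An align4 kernel is rejected when K = 3, while the align1 variant passes.
print(check_align("cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4", M=64, K=3))  # False
print(check_align("cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", M=64, K=3))  # True
```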
```diff
@@ -194,7 +197,7 @@ def profile(
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
         )
-        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
+        ops = list(filter(lambda op: self.check_align(op["name"], M, K), ops))
 
         for op in ops:
             op["runtime"] = -1
```
