diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index cec64f0af974c..4132937874b57 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -141,12 +141,13 @@ def create_gemm_operator(
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
+    # align1 variants do not seem to be available for sm80
     80: {
-        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
-        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
 }
 
@@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         self.cache = {}
 
-    def check_align(self, op_name, M):
+    def check_align(self, op_name, M, K):
         """Filter out kernels that cannot be supported."""
         aligns = re.findall(r"align[1|2|4|8]", op_name)
         assert len(aligns) == 1
+        # The same alignment is used for all axes
         align = int(aligns[0][-1])
-        if M % align != 0:
-            return False
-        return True
+        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+        # See https://github.com/NVIDIA/cutlass/issues/362.
+        # When the above issue is resolved, we can remove the alignment check on M below.
+        return M % align == 0 and K % align == 0
 
     def get_default(self, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.
@@ -194,7 +197,7 @@ def profile( ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, op_creator=partial(create_gemm_operator, batched=batched) ) - ops = list(filter(lambda op: self.check_align(op["name"], M), ops)) + ops = list(filter(lambda op: self.check_align(op["name"], M, K), ops)) for op in ops: op["runtime"] = -1