Skip to content

Commit

Permalink
add dtype and layout check in pattern match
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 11, 2021
1 parent 7743cc6 commit 6cdf205
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions python/tvm/relay/op/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,61 @@ def make_batch_matmul_pattern():


def make_conv2d_pattern():
    """Build the dataflow pattern that matches a bare conv2d call.

    Both the data and the weight operands are left unconstrained; eligibility
    is decided later by the associated check function.
    """
    # TODO(masahi): Check layout and alignment
    data = wildcard()
    weight = wildcard()
    return is_op("nn.conv2d")(data, weight)


def check_dtype(lhs, rhs):
    """Check if dtypes in the given workload are supported by CUTLASS.

    Parameters
    ----------
    lhs, rhs : tvm.ir.Type
        Checked types of the two tensor operands; only their ``dtype``
        attribute is inspected.

    Returns
    -------
    bool
        True iff both operands are float16.
    """
    # The original also tested lhs.dtype == rhs.dtype, which is implied when
    # both sides must equal "float16" — the redundant clause is dropped.
    return lhs.dtype == "float16" and rhs.dtype == "float16"


def check_gemm(call):
    """Check if the given dense workload can be offloaded to CUTLASS.

    Looks only at the dtypes of the first two arguments of ``call``.
    """
    lhs_type, rhs_type = (call.args[i].checked_type for i in (0, 1))
    return check_dtype(lhs_type, rhs_type)


def check_batch_matmul(call):
    """Check if the given batch_matmul workload can be offloaded to CUTLASS.

    Requires the dense dtype check to pass and the canonical transpose
    combination (A not transposed, B transposed).
    """
    if not check_gemm(call):
        return False
    attrs = call.attrs
    # NOTE: explicit == comparisons kept — attrs fields may be TVM attribute
    # objects, not plain Python bools.
    return attrs.transpose_a == False and attrs.transpose_b == True


def check_conv2d(call):
    """Check if the given conv2d workload can be offloaded to CUTLASS.

    Only NHWC data with OHWI kernels is supported, and both operands must
    pass the CUTLASS dtype check.
    """
    attrs = call.attrs
    if attrs.data_layout != "NHWC" or attrs.kernel_layout != "OHWI":
        return False
    data_type = call.args[0].checked_type
    weight_type = call.args[1].checked_type
    return check_dtype(data_type, weight_type)


def partition_for_cutlass(mod):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None))
dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"))
dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"))
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
dense_bias_gelu_fp16_pat = (
"cutlass.dense_bias_gelu_fp16",
make_gemm_pattern(True, "gelu"),
check_gemm,
)
dense_bias_gelu_fp32_pat = (
"cutlass.dense_bias_gelu_fp32",
make_gemm_pattern(True, "gelu", out_dtype="float32"),
check_gemm,
)
cutlass_patterns = [
dense_bias_gelu_fp16_pat,
dense_bias_gelu_fp32_pat,
dense_bias_relu_pat,
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern()),
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
# TODO(masahi): Add more conv2d patterns
("cutlass.conv2d", make_conv2d_pattern()),
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]
mod = transform.MergeComposite(cutlass_patterns)(mod)
mod = transform.AnnotateTarget(["cutlass"])(mod)
Expand Down

0 comments on commit 6cdf205

Please sign in to comment.