diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index c3a8fdc1ad8ca..90b7f9320c7f9 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -185,6 +185,9 @@ def handle_conv2d(
     d_shape,
     w_shape,
     out_shape,
+    padding,
+    strides,
+    dilation,
     out_dtype,
     profile_all,
     use_multiprocessing,
@@ -198,6 +201,9 @@ def handle_conv2d(
             d_shape,
             w_shape,
             out_shape,
+            padding,
+            strides,
+            dilation,
             out_dtype,
             profile_all=profile_all,
             use_multiprocessing=use_multiprocessing,
@@ -279,6 +285,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
                         arg0_shape,
                         arg1_shape,
                         annotator.signature["ret_shape"],
+                        annotator.op_attrs.padding,
+                        annotator.op_attrs.strides,
+                        annotator.op_attrs.dilation,
                         out_dtype,
                         profile_all,
                         use_multiprocessing,
diff --git a/python/tvm/contrib/cutlass/conv2d_profiler.py b/python/tvm/contrib/cutlass/conv2d_profiler.py
index 2e4ef4f056afb..0ece2c2bf6cc2 100644
--- a/python/tvm/contrib/cutlass/conv2d_profiler.py
+++ b/python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -33,7 +33,16 @@ def __init__(self):
 #include "cutlass/util/command_line.h"
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/reference/host/tensor_fill.h"
-#include "helper.h"
+
+#define CUTLASS_CHECK(status) \
+  { \
+    cutlass::Status error = status; \
+    if (error != cutlass::Status::kSuccess) { \
+      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
+                << std::endl; \
+      exit(EXIT_FAILURE); \
+    } \
+  }
 
 {{OperatorDef}}
 using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<{{OperatorName}}>;
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index d89efa182fc4c..f51d317f904d6 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Conv2d kernel generator and profiler for CUTLASS."""
+import re
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .gen_gemm import CutlassGemmProfiler
 from .conv2d_profiler import Conv2dProfilerEmitter
@@ -23,6 +24,7 @@
     ProfilerEngine,
     generate_sm75_tensor_op_1688,
     generate_sm80_tensor_op_16816,
+    GENERATOR_FUNC_TABLE,
 )
 from .library import (
     EpilogueFunctor,
@@ -115,12 +117,27 @@ def create_conv2d_operator(
     return ret
 
 
+DEFAULT_KERNELS = {
+    75: {
+        "float16": "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1",
+        "float32": "cutlass_tensorop_s1688fprop_optimized_f16_256x128_32x2_nhwc_align1",
+    },
+    80: {
+        "float16": "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1",
+        "float32": "cutlass_tensorop_s1688fprop_optimized_f16_256x128_32x2_nhwc_align1",
+    },
+}
+
+
 class CutlassConv2DProfiler:
     """Profile all candidate kernels and select the best one."""
 
     def __init__(self, sm, cutlass_path, binary_path):
         self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
         self.sm = sm
+        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
+        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
+        self.cache = {}
 
     def get_default(self, out_dtype):
         gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
@@ -129,29 +146,71 @@ def get_default(self, out_dtype):
         data_type = gemm_profile_result["data_type"]
         return create_conv2d_operator([tile_description], data_type, [alignment])[0]
 
+    def check_align(self, op_name, C, K):
+        """Return True if both channel counts are divisible by the kernel's alignment.
+
+        The alignment is parsed from the single ``alignN`` token found in ``op_name``.
+        """
+        aligns = re.findall(r"align[1248]", op_name)
+        assert len(aligns) == 1
+        align = int(aligns[0][-1])
+        return all(dim % align == 0 for dim in (C, K))
+
     def profile(
-        self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
+        self,
+        d_shape,
+        w_shape,
+        out_shape,
+        padding,
+        stride,
+        dilation,
+        out_dtype,
+        profile_all=True,
+        use_multiprocessing=False,
     ):
         """Profile and select the best kernel from candidate kernels.
         If profile_all is False, return immediately after the first applicable kernel is found.
         If use_multiprocessing is True, compile all profiler executables in parallel.
+
+        For padding, stride and dilation, index 0 is used for height and index 1 for width.
         """
-        B, _, _, IC = d_shape
-        OC, R, S, _ = w_shape
-        _, P, Q, _ = out_shape
-
-        M = B * P * Q
-        N = OC
-        K = R * S * IC
-
-        gemm_profile_result = self.gemm_profiler.profile(
-            M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
-        )
-
-        tile_description = gemm_profile_result["tile_description"]
-        alignment = gemm_profile_result["alignment"]
-        data_type = gemm_profile_result["data_type"]
-
-        out = create_conv2d_operator([tile_description], data_type, [alignment])[0]
-        print(out["src"])
-        return out
+        # out_shape is not passed to the profiler executable; it stays in the signature,
+        # in its original position, so the positional call in build.handle_conv2d lines up.
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
+        N, H, W, IC = d_shape
+        OC, R, S, _ = w_shape
+        ops = [op for op in ops if self.check_align(op["name"], IC, OC)]
+
+        for op in ops:
+            op["runtime"] = -1
+
+        if profile_all:
+            self.engine.compile_all(ops, use_multiprocessing)
+
+        args = [
+            "--n=%d" % N,
+            "--h=%d" % H,
+            "--w=%d" % W,
+            "--k=%d" % OC,
+            "--c=%d" % IC,
+            "--r=%d" % R,
+            "--s=%d" % S,
+            "--pad_h=%d" % padding[0],
+            "--pad_w=%d" % padding[1],
+            "--stride_h=%d" % stride[0],
+            "--stride_w=%d" % stride[1],
+            "--dilation_h=%d" % dilation[0],
+            "--dilation_w=%d" % dilation[1],
+        ]
+        for op in ops:
+            out = self.engine.evaluate(op, args)
+            op["runtime"] = out
+            if out > 0 and not profile_all:
+                break
+
+        # Kernels that were never evaluated keep the -1 sentinel and are excluded here.
+        valid_ops = [op for op in ops if op["runtime"] > 0]
+        if not valid_ops:
+            raise RuntimeError("No applicable CUTLASS conv2d kernel was found.")
+        # TODO: memoize the result in self.cache, keyed on the full conv2d workload.
+        return min(valid_ops, key=lambda op: op["runtime"])
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 58d690f8191b4..5eb649df774b0 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -24,6 +24,7 @@
     ProfilerEngine,
     generate_sm75_tensor_op_1688,
     generate_sm80_tensor_op_16816,
+    GENERATOR_FUNC_TABLE,
 )
 from .library import (
     EpilogueFunctor,
@@ -132,10 +133,4 @@ def create_gemm_operator(
     return ret
 
 
-GENERATOR_FUNC_TABLE = {
-    75: generate_sm75_tensor_op_1688,
-    80: generate_sm80_tensor_op_16816,
-}
-
-
 # TODO(masahi): A sensible way to pick reasonable default kernels
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index cc228737cefc6..6426ed5b2161d 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -159,6 +159,12 @@ def get_tile_descriptions(math_inst):
     return sm75_kernels + sm80_kernels
 
 
+GENERATOR_FUNC_TABLE = {
+    75: generate_sm75_tensor_op_1688,
+    80: generate_sm80_tensor_op_16816,
+}
+
+
 class ProfilerEngine:
     """Compile and run a given profiler executable."""
 
@@ -185,6 +191,7 @@ def _compile(self, op):
         fi.write(op["src"])
         fi.close()
         cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
+        logger.debug("%s", cmd)
         os.system(cmd)
         os.unlink(fi.name)
 
@@ -204,13 +211,10 @@ def evaluate(self, op, args):
         if not os.path.exists(opath):
             self._compile(op)
         cmd = [opath]
         if args is not None:
-            cmd.append(str(args[0]))
-            cmd.append(str(args[1]))
-            cmd.append(str(args[2]))
-            if len(args) > 3:
-                cmd.append(str(args[3]))
+            cmd.extend(str(arg) for arg in args)
         try:
+            logger.debug("%s", " ".join(cmd))
             sp = subprocess.run(cmd, capture_output=True, check=True)
             rt = float(sp.stdout)
             logger.info("%s, %f", op_name, rt)
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 585b42a214259..f5f30b8d0dece 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -110,7 +110,7 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
     return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16")
 
 
-def get_conv2d_nchw(d_shape, w_shape):
+def get_conv2d_nchw(d_shape, w_shape, out_dtype="float16"):
     data = relay.var("data", shape=d_shape, dtype="float16")
     weight = relay.var("weight", shape=w_shape, dtype="float16")
     out_channel = w_shape[0]
@@ -121,7 +121,7 @@ def get_conv2d_nchw(d_shape, w_shape):
             kernel_size=(3, 3),
             channels=out_channel,
             padding=(1, 1),
-            out_dtype="float16",
+            out_dtype=out_dtype,
         )
     )
 
@@ -129,7 +129,7 @@ def get_conv2d_nchw(d_shape, w_shape):
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
-        mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
+        mod, sm, profile_all=True, use_multiprocessing=True, tmp_dir=tmp_dir
     )
     with tvm.transform.PassContext(opt_level=3):
         lib = relay.build(mod, target="cuda", params=params)