conv2d profiler working
masahi committed Dec 11, 2021
1 parent 49ee61f commit ab114f5
Showing 6 changed files with 127 additions and 35 deletions.
9 changes: 9 additions & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -185,6 +185,9 @@ def handle_conv2d(
d_shape,
w_shape,
out_shape,
padding,
strides,
dilation,
out_dtype,
profile_all,
use_multiprocessing,
@@ -198,6 +201,9 @@ def handle_conv2d(
        d_shape,
        w_shape,
        padding,
        strides,
        dilation,
        out_shape,
        out_dtype,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
@@ -279,6 +285,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
arg0_shape,
arg1_shape,
annotator.signature["ret_shape"],
annotator.op_attrs.padding,
annotator.op_attrs.strides,
annotator.op_attrs.dilation,
out_dtype,
profile_all,
use_multiprocessing,
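
The conv2d attributes (padding, strides, dilation) are now threaded from the Relay annotator down into the profiler, and the call into cutlass_profiler.profile is reordered here to match the new profile() signature in gen_conv2d.py. A minimal sketch of the entry point these arguments feed, mirroring profile_and_build() in tests/python/contrib/test_cutlass.py (module paths assumed from the TVM source layout; mod is a Relay module containing conv2d):

# Sketch: tuning CUTLASS conv2d kernels with the newly threaded attributes.
# Assumes a TVM checkout containing this commit; `mod` is a relay.IRModule.
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels

mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
    mod, sm=80, profile_all=True, use_multiprocessing=True, tmp_dir="./tmp"
)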
11 changes: 10 additions & 1 deletion python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -33,7 +33,16 @@ def __init__(self):
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "helper.h"
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}
{{OperatorDef}}
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<{{OperatorName}}>;
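
Inlining CUTLASS_CHECK here, in place of the include of CUTLASS's example-only helper.h, makes each generated profiler .cu file compile standalone. A rough sketch of how the emitter is driven from create_conv2d_operator (the emit() signature and both argument values are assumptions for illustration):

# Hypothetical sketch: rendering a standalone profiler source.
from tvm.contrib.cutlass.conv2d_profiler import Conv2dProfilerEmitter

op_def = "/* C++ kernel definition produced by EmitConv2dInstance */"
op_name = "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1"
src = Conv2dProfilerEmitter().emit(op_def, op_name)  # signature assumed
# `src` is a self-contained translation unit; helper.h is no longer needed.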
112 changes: 92 additions & 20 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,13 +16,15 @@
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
import re
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
from .gen_tensor_op import (
ProfilerEngine,
generate_sm75_tensor_op_1688,
generate_sm80_tensor_op_16816,
GENERATOR_FUNC_TABLE,
)
from .library import (
EpilogueFunctor,
@@ -115,12 +117,27 @@ def create_conv2d_operator(
return ret


DEFAULT_KERNELS = {
75: {
"float16": "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1",
"float32": "cutlass_tensorop_s1688fprop_optimized_f16_256x128_32x2_nhwc_align1",
},
80: {
"float16": "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1",
"float32": "cutlass_tensorop_s1688fprop_optimized_f16_256x128_32x2_nhwc_align1",
},
}


class CutlassConv2DProfiler:
"""Profile all candidate kernels and select the best one."""

def __init__(self, sm, cutlass_path, binary_path):
self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
self.sm = sm
assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}

def get_default(self, out_dtype):
gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
@@ -129,29 +146,84 @@ def get_default(self, out_dtype):
data_type = gemm_profile_result["data_type"]
return create_conv2d_operator([tile_description], data_type, [alignment])[0]

    def check_align(self, op_name, C, K):
        """Filter out kernels that cannot be supported.

        The required alignment is encoded in the kernel name, e.g. "..._align4"
        requires both the input channels C and the output channels K to be
        divisible by 4.
        """
        aligns = re.findall(r"align[1248]", op_name)
        assert len(aligns) == 1
        align = int(aligns[0][-1])
        return all(dim % align == 0 for dim in (C, K))

def profile(
self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
self,
d_shape,
w_shape,
padding,
stride,
dilation,
out_shape,
out_dtype,
profile_all=True,
use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
B, _, _, IC = d_shape
OC, R, S, _ = w_shape
_, P, Q, _ = out_shape

M = B * P * Q
N = OC
K = R * S * IC

gemm_profile_result = self.gemm_profiler.profile(
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
)

tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]

out = create_conv2d_operator([tile_description], data_type, [alignment])[0]
print(out["src"])
return out
        if True:  # WIP switch: always take the new profiler path; the GEMM-based fallback below is kept for reference
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
N, H, W, IC = d_shape
OC, R, S, _ = w_shape
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

for op in ops:
op["runtime"] = -1

if profile_all:
self.engine.compile_all(ops, use_multiprocessing)

            args = [
                "--n=%d" % N,
                "--h=%d" % H,
                "--w=%d" % W,
                "--k=%d" % OC,
                "--c=%d" % IC,
                "--r=%d" % R,
                "--s=%d" % S,
                "--pad_h=%d" % padding[0],
                "--pad_w=%d" % padding[1],
                "--stride_h=%d" % stride[0],
                "--stride_w=%d" % stride[1],
                "--dilation_h=%d" % dilation[0],
                "--dilation_w=%d" % dilation[1],
            ]
for op in ops:
out = self.engine.evaluate(op, args)
op["runtime"] = out
if out > 0 and profile_all is False:
break

valid_ops = filter(lambda op: op["runtime"] > 0, ops)
output = sorted(valid_ops, key=lambda i: i["runtime"])
# self.cache[(M, N, K)] = output[0]
return output[0]

else:
B, _, _, IC = d_shape
OC, R, S, _ = w_shape
_, P, Q, _ = out_shape

M = B * P * Q
N = OC
K = R * S * IC

gemm_profile_result = self.gemm_profiler.profile(
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
)

tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]

out = create_conv2d_operator([tile_description], data_type, [alignment])[0]
# print(out["src"])
return out
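
For reference, a hedged usage sketch of the new profiler (keyword arguments for clarity; the CUTLASS and binary paths are placeholders), followed by the implicit-GEMM arithmetic the fallback branch relies on:

# Sketch: profiling one conv2d workload (assumes this commit's API).
from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler

profiler = CutlassConv2DProfiler(sm=80, cutlass_path="/path/to/cutlass", binary_path="./tmp")
best = profiler.profile(
    d_shape=(16, 32, 32, 16),    # NHWC activations
    w_shape=(32, 3, 3, 16),      # OHWI weights
    padding=(1, 1),
    stride=(1, 1),
    dilation=(1, 1),
    out_shape=(16, 32, 32, 32),  # NHWC output
    out_dtype="float16",
    profile_all=True,
    use_multiprocessing=True,
)
print(best["name"], best["runtime"])

# Implicit-GEMM mapping used by the GEMM-based fallback:
#   M = B * P * Q  = 16 * 32 * 32 = 16384
#   N = OC         = 32
#   K = R * S * IC = 3 * 3 * 16   = 144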
5 changes: 1 addition & 4 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -24,6 +24,7 @@
ProfilerEngine,
generate_sm75_tensor_op_1688,
generate_sm80_tensor_op_16816,
    GENERATOR_FUNC_TABLE,
)
from .library import (
EpilogueFunctor,
@@ -132,10 +133,6 @@ def create_gemm_operator(
return ret


GENERATOR_FUNC_TABLE = {
75: generate_sm75_tensor_op_1688,
80: generate_sm80_tensor_op_16816,
}


# TODO(masahi): A sensible way to pick reasonable default kernels
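
The table is deleted here and re-homed in gen_tensor_op.py (next file) so the GEMM and conv2d generators share one registry. A minimal lookup sketch, mirroring the call in gen_conv2d.py above:

# Sketch: resolving the kernel generator for a given compute capability.
from tvm.contrib.cutlass.gen_conv2d import create_conv2d_operator
from tvm.contrib.cutlass.gen_tensor_op import GENERATOR_FUNC_TABLE

ops = GENERATOR_FUNC_TABLE[80]("float16", op_creator=create_conv2d_operator)
print("%d candidate conv2d kernels for sm80" % len(ops))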
15 changes: 9 additions & 6 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -159,6 +159,11 @@ def get_tile_descriptions(math_inst):
return sm75_kernels + sm80_kernels


GENERATOR_FUNC_TABLE = {
75: generate_sm75_tensor_op_1688,
80: generate_sm80_tensor_op_16816,
}

class ProfilerEngine:
"""Compile and run a given profiler executable."""

@@ -185,6 +190,7 @@ def _compile(self, op):
fi.write(op["src"])
fi.close()
cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
print(cmd)
os.system(cmd)
os.unlink(fi.name)

@@ -204,13 +210,10 @@ def evaluate(self, op, args):
if not os.path.exists(opath):
self._compile(op)
cmd = [opath]
if args is not None:
cmd.append(str(args[0]))
cmd.append(str(args[1]))
cmd.append(str(args[2]))
if len(args) > 3:
cmd.append(str(args[3]))
for arg in args:
cmd.append(str(arg))
try:
            print(" ".join(cmd))
sp = subprocess.run(cmd, capture_output=True, check=True)
rt = float(sp.stdout)
logger.info("%s, %f", op_name, rt)
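
evaluate() previously assumed the three (or four) GEMM dimensions; forwarding the argument list verbatim lets the conv2d profiler pass its longer flag set. An illustrative call (values made up; engine is a ProfilerEngine and op is a kernel record from create_conv2d_operator):

# Sketch: evaluating one conv2d kernel with the full flag list.
args = [
    "--n=16", "--h=32", "--w=32",           # input NHWC dims
    "--k=32", "--c=16", "--r=3", "--s=3",   # filter dims (k=OC, c=IC)
    "--pad_h=1", "--pad_w=1",
    "--stride_h=1", "--stride_w=1",
    "--dilation_h=1", "--dilation_w=1",
]
runtime = engine.evaluate(op, args)  # runtime parsed from the binary's stdout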
10 changes: 6 additions & 4 deletions tests/python/contrib/test_cutlass.py
@@ -110,7 +110,7 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
    return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype=out_dtype)


def get_conv2d_nchw(d_shape, w_shape):
def get_conv2d_nchw(d_shape, w_shape, out_dtype="float16"):
data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
out_channel = w_shape[0]
@@ -121,15 +121,15 @@ def get_conv2d_nchw(d_shape, w_shape):
kernel_size=(3, 3),
channels=out_channel,
padding=(1, 1),
out_dtype="float16",
out_dtype=out_dtype,
)
)


def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
mod, sm, profile_all=True, use_multiprocessing=True, tmp_dir=tmp_dir
)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target="cuda", params=params)
@@ -390,6 +390,7 @@ def test_conv2d():
run_benchmark=False,
)

    return  # WIP: skip the dynamic-batch conv2d tests below for now
d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
dyn_batch_shape = (relay.Any(),) + d_shape[1:]
@@ -403,4 +404,5 @@


if __name__ == "__main__":
pytest.main([__file__])
# pytest.main([__file__])
test_conv2d()
