
add cache
masahi committed Dec 11, 2021
1 parent f7d17a1 commit 6db7172
Showing 1 changed file with 26 additions and 31 deletions.
57 changes: 26 additions & 31 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -117,25 +117,13 @@ def create_conv2d_operator(
    return ret


DEFAULT_KERNELS = {
    75: {
        "float16": "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1",
        "float32": "cutlass_tensorop_s1688fprop_optimized_f16_256x128_32x2_nhwc_align1",
    },
    80: {
        "float16": "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align1",
        "float32": "cutlass_tensorop_s1688fprop_optimized_f16_256x128_32x2_nhwc_align1",
    },
}


class CutlassConv2DProfiler:
    """Profile all candidate kernels and select the best one."""

    def __init__(self, sm, cutlass_path, binary_path):
        self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
        self.sm = sm
        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
        assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm
        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
        self.cache = {}

@@ -169,9 +157,28 @@ def profile(
        If profile_all is False, return immediately after the first applicable kernel is found.
        If use_multiprocessing is True, compile all profiler executables in parallel.
        """
        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
        N, H, W, IC = d_shape
        OC, R, S, _ = w_shape
        workload = (
            N,
            H,
            W,
            IC,
            OC,
            R,
            S,
            padding[0],
            padding[1],
            stride[0],
            stride[1],
            dilation[0],
            dilation[1],
        )

        if workload in self.cache:
            return self.cache[workload]

        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
        ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

        for op in ops:
@@ -180,28 +187,16 @@
        if profile_all:
            self.engine.compile_all(ops, use_multiprocessing)

        args = [
            "--n=%d" % N,
            "--h=%d" % H,
            "--w=%d" % W,
            "--k=%d" % OC,
            "--c=%d" % IC,
            "--r=%d" % R,
            "--s=%d" % S,
            "--pad_h=%d" % padding[0],
            "--pad_w=%d" % padding[1],
            "--stride_h=%d" % stride[0],
            "--stride_w=%d" % stride[1],
            "--dilation_h=%d" % dilation[0],
            "--dilation_w=%d" % dilation[1],
        ]
        args = (
            "--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d "
            "--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
        ) % workload

        for op in ops:
            out = self.engine.evaluate(op, args)
            out = self.engine.evaluate(op, args.split(" "))
            op["runtime"] = out
            if out > 0 and profile_all is False:
                break

        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
        output = sorted(valid_ops, key=lambda i: i["runtime"])
        # self.cache[(M, N, K)] = output[0]
        self.cache[workload] = output[0]
        return output[0]
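
The essence of this commit is memoizing profiling results keyed by the full conv2d workload, so a second fusion with the same shapes, padding, strides, and dilations does not recompile and rerun the profiler; the same workload tuple also feeds the %-formatted args string passed to the profiler binary. The sketch below illustrates only that caching pattern in isolation, under stated assumptions: FakeEngine and CachedConv2DProfiler are hypothetical stand-ins, not the real ProfilerEngine or CutlassConv2DProfiler, and the real profile additionally generates, filters, and ranks candidate ops before caching the best one.

import time


class FakeEngine:
    """Hypothetical stand-in for ProfilerEngine: pretends evaluation is expensive."""

    def evaluate(self, workload):
        time.sleep(0.01)  # simulate compiling and running a profiler executable
        return sum(workload) % 7 + 1  # fake, always-positive "runtime"


class CachedConv2DProfiler:
    """Memoizes profiling results keyed by the full conv2d workload (sketch only)."""

    def __init__(self):
        self.engine = FakeEngine()
        self.cache = {}

    def profile(self, d_shape, w_shape, padding, stride, dilation):
        N, H, W, IC = d_shape
        OC, R, S, _ = w_shape
        # The key covers every parameter that can change which kernel wins.
        workload = (N, H, W, IC, OC, R, S, *padding, *stride, *dilation)
        if workload in self.cache:
            return self.cache[workload]  # identical workload: skip re-profiling
        result = self.engine.evaluate(workload)
        self.cache[workload] = result
        return result


profiler = CachedConv2DProfiler()
conv = ((8, 56, 56, 64), (64, 3, 3, 64), (1, 1), (1, 1), (1, 1))
first = profiler.profile(*conv)   # pays the (simulated) profiling cost
second = profiler.profile(*conv)  # served from the cache
assert first == second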
