Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 14, 2021
1 parent 3d3d24f commit 7e43d42
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
1 change: 0 additions & 1 deletion python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def handle_conv2d(
out = cutlass_profiler.profile(
d_shape,
w_shape,
out_shape,
padding,
strides,
dilation,
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def create_conv2d_operator(
swizzling_functor_,
)

# TODO(masahi): Add profiler source here
op_entry["opdef"] = kernel_emitter.emit(op)
op_entry["op"] = op
op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name())
Expand Down Expand Up @@ -143,7 +142,6 @@ def profile(
self,
d_shape,
w_shape,
out_shape,
padding,
stride,
dilation,
Expand Down
10 changes: 6 additions & 4 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_conv2d_nchw(d_shape, w_shape, padding, out_dtype="float16"):
def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
mod, sm, profile_all=True, use_multiprocessing=True, tmp_dir=tmp_dir
mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target="cuda", params=params)
Expand Down Expand Up @@ -376,7 +376,8 @@ def test_conv2d():
for IC in [3, 16]:
d_shape = (16, IC, 32, 32)
w_shape = (32, IC, 3, 3)
mod_nchw = get_conv2d_nchw(d_shape, w_shape)
padding = (1, 1)
mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)

verify_conv2d(
mod_nchw,
Expand All @@ -392,10 +393,11 @@ def test_conv2d():

d_shape = (16, 16, 32, 32)
w_shape = (32, 16, 3, 3)
padding = (1, 1)
dyn_batch_shape = (relay.Any(),) + d_shape[1:]

mod_nchw = get_conv2d_nchw(d_shape, w_shape)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)
mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)
mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape, padding)

verify_conv2d(
mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
Expand Down

0 comments on commit 7e43d42

Please sign in to comment.