Skip to content

Commit

Permalink
support split k in profiler
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 2eb1cf4 commit 6e4c7e1
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 46 deletions.
5 changes: 2 additions & 3 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def handle_conv2d(
out_dtype,
data_dtype,
weight_dtype,
split_k_slices,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
Expand Down Expand Up @@ -290,6 +290,7 @@ def tune_cutlass_kernels(
mod,
sm,
use_3xtf32=True,
split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
Expand Down Expand Up @@ -369,8 +370,6 @@ def tune_cutlass_kernels(
d_shape = arg0_shape
w_shape = arg1_shape

split_k_slices = [8]

new_attrs.update(
handle_conv2d(
conv2d_profiler,
Expand Down
11 changes: 6 additions & 5 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def __init__(self):
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
"""

def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
def emit(
self, operation, no_beta_scaling=False, residual_block_info=False, emit_reduction=False
):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
Expand All @@ -248,12 +250,11 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
)

element_c = operation.C.element
element_c_gemm = element_c
use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1
# Gemm output always fp32 in wgrad with split k
element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c

if use_split_k_wgrad:
# split k, assumes fp32 accum. gemm output always fp32
element_c_gemm = DataType.f32
if emit_reduction:
epilogue_reduction = substitute_template(
self.epilogue_wgrad,
{
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self):
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
int split_k_slices = {{SplitK}};
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
Expand All @@ -100,7 +101,7 @@ def __init__(self):
options.dilation,
options.output_size(),
cutlass::conv::Mode::kCrossCorrelation,
1
split_k_slices
);
auto conv_kind = ImplicitGemm::kConvolutionalOperator;
Expand All @@ -115,13 +116,17 @@ def __init__(self):
using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ?
cutlass::conv::SplitKMode::kParallel : cutlass::conv::SplitKMode::kSerial;
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_c.device_ref(),
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
split_k_mode,
};
ImplicitGemm implicit_gemm_op;
Expand Down Expand Up @@ -166,6 +171,6 @@ def __init__(self):
"""
)

def emit(self, op_def, op_name):
src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
def emit(self, op_def, op_name, split_k_slices=1):
src = self.template.render(OperatorDef=op_def, OperatorName=op_name, SplitK=split_k_slices)
return src
14 changes: 10 additions & 4 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def create_conv2d_operator_with_epilogue(

name = op.procedural_name()
opdef = EmitConv2dInstance().emit(
op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
op,
no_beta_scaling=no_beta_scaling,
residual_block_info=residual_block_info,
emit_reduction=split_k_slices > 1,
)

return name, opdef
Expand Down Expand Up @@ -142,17 +145,20 @@ def enumerate_conv2d_operators(
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
split_k_slice
split_k_slice,
)

ret.append(
{
"src": profiler_emitter.emit(kernel_emitter.emit(op), op.procedural_name()),
"src": profiler_emitter.emit(
kernel_emitter.emit(op), op.procedural_name(), split_k_slice
),
"name": op.procedural_name(),
"tile_description": tile,
"alignment": alignment,
"data_type": data_type,
"swizzle_functor": swizzling_functor,
"split_k_slices": split_k_slice,
}
)

Expand Down Expand Up @@ -203,7 +209,7 @@ def get_default(
data_type,
alignment,
swizzling_functor,
split_k_slices=1
split_k_slices=1,
)
return {"name": name, "opdef": opdef}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
out_dtype=attrs.out_dtype
out_dtype=attrs.out_dtype,
)

# infer shape of backward_weight
Expand Down
23 changes: 13 additions & 10 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, attrs.at("op_def"));
CutlassPrint(conv2d_decl, "using Operation_" + op_name +
" = cutlass::conv::device::ImplicitGemmConvolution<" +
op_name + ">;\n");
" = cutlass::conv::device::ImplicitGemmConvolution<" + op_name +
">;\n");
CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n");
CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n");
Expand Down Expand Up @@ -322,8 +322,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const bool use_split_k = op_name.find("splitk") != std::string::npos;

if (use_split_k) {
std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789"));
LOG(INFO) << "split_k : " << split_k_slices;
std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789") + 1);
CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + ";\n");
} else {
CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n");
Expand Down Expand Up @@ -388,7 +387,11 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n");
}

CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput ;\n");
if (use_split_k) {
CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput;\n");
} else {
CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n");
}

std::string tensor_c_init = "{static_cast<ElementOutput*>(ptr_out), layout_C}";
if (has_residual_block) {
Expand All @@ -410,8 +413,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");

if (use_split_k) {
CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
} else {
CutlassPrint(conv2d_decl, " tensor_c,\n");
CutlassPrint(conv2d_decl, " tensor_d,\n");
Expand Down Expand Up @@ -442,9 +445,9 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");

if (use_split_k) {
CutlassPrint(
conv2d_decl,
"arguments.ref_D.reset(reinterpret_cast<ElementComputeEpilogue*>(workspace.get()), layout_D);\n\n");
CutlassPrint(conv2d_decl,
"arguments.ref_D.reset(reinterpret_cast<ElementComputeEpilogue*>(workspace.get()),"
" layout_D);\n\n");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
Expand Down
58 changes: 38 additions & 20 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,24 @@ def get_random_ndarray(shape, dtype):


def profile_and_build(
mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False, use_3xtf32=True
mod,
params,
sm,
split_k_slices=[1],
tmp_dir="./tmp",
lib_path="compile.so",
use_fast_math=False,
use_3xtf32=True,
):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
mod,
sm,
use_3xtf32=use_3xtf32,
split_k_slices=split_k_slices,
profile_all_alignments=False,
find_first_valid=True,
use_multiprocessing=False,
use_multiprocessing=True,
tmp_dir=tmp_dir,
)
with tvm.transform.PassContext(opt_level=3):
Expand All @@ -277,6 +285,7 @@ def profile_and_build_vm(
mod,
params,
sm,
split_k_slices=[1],
tmp_dir="./tmp",
lib_path="compile.so",
vmcode_path="vmcode.ro",
Expand All @@ -287,6 +296,7 @@ def profile_and_build_vm(
mod, num_cutlass_partition = tune_cutlass_kernels(
mod,
sm,
split_k_slices=split_k_slices,
use_3xtf32=use_3xtf32,
profile_all_alignments=False,
find_first_valid=True,
Expand Down Expand Up @@ -520,6 +530,7 @@ def verify_conv2d_common(
inputs,
params,
sm=80,
split_k_slices=[1],
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
Expand All @@ -530,7 +541,7 @@ def verify_conv2d_common(
):
if not has_cutlass():
return
if sm < 80 and data_dtype == "float32":
if sm < 80 and inputs[0].dtype == "float32":
return

mod_nchw = tvm.IRModule.from_expr(expr_nchw)
Expand All @@ -555,7 +566,7 @@ def verify_conv2d_common(
)

rt_mod, _, num_cutlass_partition = profile_and_build_func(
mod_weight_ohwi, params, sm, use_fast_math=use_fast_math
mod_weight_ohwi, params, sm, split_k_slices, use_fast_math=use_fast_math
)
out = get_output_func(rt_mod, input_names, inputs)

Expand Down Expand Up @@ -609,13 +620,16 @@ def verify_conv2d(
np_bias = get_random_ndarray((w_shape[0],), typ.dtype)
params = {"weight": np_weight, "bias": np_bias}

split_k_slices = [1]

return verify_conv2d_common(
expr_nchw,
expr_ref,
["data"],
[np_data],
params,
sm,
split_k_slices,
atol,
rtol,
use_cudnn_ref,
Expand All @@ -632,6 +646,7 @@ def verify_conv2d_backward_weight(
grad_shape,
data_shape,
sm=80,
split_k_slices=[1],
atol=1e-5,
rtol=1e-5,
use_cudnn_ref=False,
Expand All @@ -652,6 +667,7 @@ def verify_conv2d_backward_weight(
[np_grad, np_data],
params,
sm,
split_k_slices,
atol,
rtol,
use_cudnn_ref,
Expand Down Expand Up @@ -829,8 +845,8 @@ def test_conv2d_transpose():


def test_conv2d_backward_weight():
OC = 32
IC = 32
OC = 8
IC = 16
d_shape = (16, IC, 32, 32)
w_shape = (OC, IC, 3, 3)
dtype = "float16"
Expand All @@ -850,18 +866,20 @@ def test_conv2d_backward_weight():
weight_dtype=dtype,
)

verify_conv2d_backward_weight(
mod_nchw,
mod_nchw,
o_shape,
d_shape,
sm=80,
atol=1e-3,
rtol=1e-3,
use_cudnn_ref=False,
grad_dtype=dtype,
data_dtype=dtype,
)
for split_k_slices in [1, 8]:
verify_conv2d_backward_weight(
mod_nchw,
mod_nchw,
o_shape,
d_shape,
sm=80,
split_k_slices=[split_k_slices],
atol=1e-3,
rtol=1e-3,
use_cudnn_ref=False,
grad_dtype=dtype,
data_dtype=dtype,
)


def test_conv2d_bwd():
Expand Down Expand Up @@ -913,5 +931,5 @@ def test_conv2d_bwd():


if __name__ == "__main__":
# pytest.main([__file__])
test_conv2d_backward_weight()
pytest.main([__file__])
# test_conv2d_backward_weight()

0 comments on commit 6e4c7e1

Please sign in to comment.