Skip to content

Commit

Permalink
fp32 output works
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 7a51995 commit 30df1bd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(self):
${element_c},
4,
float,
-float,
+float
>"""

self.template = """
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def create_conv2d_operator_with_epilogue(
data_type,
alignment,
swizzling_functor,
-split_k_slices=1,
+split_k_slices=8,
):
"""
Instantiate a cutlass kernel from the given configuration,
Expand Down Expand Up @@ -109,7 +109,7 @@ def enumerate_conv2d_operators(
data_type,
alignment_constraints,
swizzling_functor=SwizzlingFunctor.Identity4,
-split_k_slices=[1],
+split_k_slices=[8],
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
Expand Down
49 changes: 23 additions & 26 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,17 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
"cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");
// Check the problem size is supported or not
CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
-CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");

if (use_split_k) {
CutlassPrint(
conv2d_decl,
-"arguments.ref_D.reset(reinterpret_cast<ElementCompute*>(workspace.get()), layout_D);\n");
+"arguments.ref_D.reset(reinterpret_cast<ElementComputeEpilogue*>(workspace.get()), layout_D);\n\n");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
-CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");

if (use_split_k) {
CutlassPrint(
Expand All @@ -453,47 +453,44 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,

// Launch initialized CUTLASS kernel
CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
-CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");

if (use_split_k) {
-CutlassPrint(conv2d_decl, " ReductionDevice reduction_op;\n");
+CutlassPrint(conv2d_decl, "ReductionDevice reduction_op;\n");
 CutlassPrint(conv2d_decl,
-" const static cutlass::conv::Operator kConvolutionalOperator = "
+"const static cutlass::conv::Operator kConvolutionalOperator = "
 "Conv2d::kConvolutionalOperator;\n");
-CutlassPrint(conv2d_decl, " typename ReductionDevice::Arguments reduction_args(\n");
+CutlassPrint(conv2d_decl, "typename ReductionDevice::Arguments reduction_args(\n");
 CutlassPrint(conv2d_decl,
-" cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, "
+"cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, "
 "problem_size).mn(),\n");
-CutlassPrint(conv2d_decl, " problem_size.split_k_slices,\n");
+CutlassPrint(conv2d_decl, "problem_size.split_k_slices,\n");
 CutlassPrint(conv2d_decl,
-" cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, "
+"cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, "
 "problem_size),\n");
-CutlassPrint(conv2d_decl, " {\n");
+CutlassPrint(conv2d_decl, "{\n");
 CutlassPrint(conv2d_decl,
-" reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
+"   reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
 CutlassPrint(conv2d_decl,
-" "
 "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
 "kTensorCStrideIdx])\n");
-CutlassPrint(conv2d_decl, " },\n");
-CutlassPrint(conv2d_decl, " {\n");
-CutlassPrint(conv2d_decl, " tensor_d.data(),\n");
+CutlassPrint(conv2d_decl, "},\n");
+CutlassPrint(conv2d_decl, "{\n");
+CutlassPrint(conv2d_decl, "tensor_d.data(),\n");
 CutlassPrint(conv2d_decl,
-" "
 "ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::"
 "kTensorCStrideIdx])\n");
-CutlassPrint(conv2d_decl, " },\n");
-CutlassPrint(conv2d_decl, " {\n");
-CutlassPrint(conv2d_decl, " tensor_c.data(),\n");
+CutlassPrint(conv2d_decl, "},\n");
+CutlassPrint(conv2d_decl, "{\n");
+CutlassPrint(conv2d_decl, "tensor_c.data(),\n");
 CutlassPrint(conv2d_decl,
-" "
 "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
 "kTensorCStrideIdx])\n");
-CutlassPrint(conv2d_decl, " },\n");
-CutlassPrint(conv2d_decl, " {alpha, beta}\n");
-CutlassPrint(conv2d_decl, " );\n\n");
-CutlassPrint(conv2d_decl, " status = reduction_op.initialize(reduction_args, nullptr);\n");
-CutlassPrint(conv2d_decl, " status = reduction_op();\n");
+CutlassPrint(conv2d_decl, "},\n");
+CutlassPrint(conv2d_decl, " {alpha, beta}\n");
+CutlassPrint(conv2d_decl, ");\n\n");
+CutlassPrint(conv2d_decl, "status = reduction_op.initialize(reduction_args, nullptr);\n");
+CutlassPrint(conv2d_decl, "status = reduction_op();\n");
}

return conv2d_decl.str();
Expand Down

0 comments on commit 30df1bd

Please sign in to comment.