Skip to content

Commit

Permalink
fix compile error for fprop
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 31f2543 commit 084d5c4
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,14 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
// TODO: split_k_slices is hard-coded below; revisit.
const int split_k_slices = 8;
const int split_k_slices = 1;
CutlassPrint(conv2d_decl, "int split_k_slices = " + std::to_string(split_k_slices) + ";\n");

CutlassPrint(
conv2d_decl,
"cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
"stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices);\n");
"stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, "
"split_k_slices);\n");

const bool use_split_k = split_k_slices > 1;
const std::string split_k_mode = use_split_k > 1 ? "kParallel" : "kSerial";
Expand Down Expand Up @@ -403,18 +404,19 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, " tensor_d,\n");

if (has_residual_block) {
ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion";
CutlassPrint(conv2d_decl, "{alpha, beta},\n");
CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices
CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
} else if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n},\n");
CutlassPrint(conv2d_decl, " {alpha},\n");
CutlassPrint(conv2d_decl, "split_k_mode\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta},\n");
CutlassPrint(conv2d_decl, "split_k_mode\n};\n");
}

CutlassPrint(conv2d_decl, "split_k_mode\n};\n");

CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
Expand All @@ -429,18 +431,21 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

if (use_split_k) {
CutlassPrint(conv2d_decl, "arguments.ref_D.reset(reinterpret_cast<ElementOutput*>(workspace.get())); \n");
CutlassPrint(conv2d_decl, "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n");
CutlassPrint(conv2d_decl,
"\narguments.ref_D.reset(reinterpret_cast<ElementOutput*>(workspace.get())); \n");
CutlassPrint(
conv2d_decl,
"arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n");
CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n");
}

// Launch initialized CUTLASS kernel
CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

if (use_split_k) {
CutlassPrint(conv2d_decl, "using EpilogueOutputOp = Conv2d::EpilogueOutputOp;\n");
CutlassPrint(conv2d_decl, "\nusing EpilogueOutputOp = Conv2d::EpilogueOutputOp;\n");
CutlassPrint(conv2d_decl, "using ReductionOp = cutlass::reduction::thread::ReduceAdd<\n");
CutlassPrint(conv2d_decl, " Conv2d::ElementAccumulator,\n");
CutlassPrint(conv2d_decl, " typename EpilogueOutputOp::ElementAccumulator,\n");
Expand Down Expand Up @@ -472,7 +477,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
" cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, "
"problem_size),\n");
CutlassPrint(conv2d_decl, " {\n");
CutlassPrint(conv2d_decl, " reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
CutlassPrint(conv2d_decl,
" reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
CutlassPrint(conv2d_decl,
" "
"ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
Expand Down

0 comments on commit 084d5c4

Please sign in to comment.