Skip to content

Commit

Permalink
fixed for fp16 output
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 30df1bd commit 0bce8f3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n");
CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n");
CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n");
CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = Conv2d::ElementAccumulator;\n");

auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) {
Expand Down Expand Up @@ -382,6 +381,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n");
}

CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput ;\n");

std::string tensor_c_init = "{static_cast<ElementOutput*>(ptr_out), layout_C}";
if (has_residual_block) {
tensor_c_init = "{static_cast<ElementOutput*>(ptr_residual), layout_C}";
Expand Down

0 comments on commit 0bce8f3

Please sign in to comment.