Skip to content

Commit

Permalink
test worked with fp32 output
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 084d5c4 commit 08a6147
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
out_dtype=attrs.out_dtype
)

# infer shape of backward_weight
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
// TODO
const int split_k_slices = 1;
const int split_k_slices = 8;
CutlassPrint(conv2d_decl, "int split_k_slices = " + std::to_string(split_k_slices) + ";\n");

CutlassPrint(
Expand All @@ -329,7 +329,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
"split_k_slices);\n");

const bool use_split_k = split_k_slices > 1;
const std::string split_k_mode = use_split_k > 1 ? "kParallel" : "kSerial";
const std::string split_k_mode = use_split_k ? "kParallel" : "kSerial";
CutlassPrint(conv2d_decl,
"const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" +
split_k_mode + ";\n");
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
{out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});

auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
reporter->Assign(types[2], TensorType(wshape, data->dtype));
reporter->Assign(types[2], TensorType(wshape, param->out_dtype));
return true;
}

Expand Down
10 changes: 4 additions & 6 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,11 @@ def test_conv2d_transpose():


def test_conv2d_backward_weight():
OC = 8
IC = 16
OC = 32
IC = 32
d_shape = (16, IC, 32, 32)
w_shape = (OC, IC, 3, 3)
dtype = "float32"
dtype = "float16"

for strides in [(1, 1), (2, 2)]:
o_shape = (16, OC, 32 // strides[0], 32 // strides[1])
Expand All @@ -845,7 +845,7 @@ def test_conv2d_backward_weight():
o_shape,
padding,
strides,
out_dtype=dtype,
out_dtype="float32",
data_dtype=dtype,
weight_dtype=dtype,
)
Expand All @@ -863,8 +863,6 @@ def test_conv2d_backward_weight():
data_dtype=dtype,
)

# split k


def test_conv2d_bwd():
IC = 16
Expand Down

0 comments on commit 08a6147

Please sign in to comment.