diff --git a/paddle/phi/kernels/onednn/conv_handler.h b/paddle/phi/kernels/onednn/conv_handler.h index 1473cb1b5a2483..86baabf45afc10 100644 --- a/paddle/phi/kernels/onednn/conv_handler.h +++ b/paddle/phi/kernels/onednn/conv_handler.h @@ -180,7 +180,7 @@ class ConvOneDNNHandlerT weights_md = funcs::OneDNNMemDesc( weights_tz, data_type, funcs::OneDNNMemoryFormat::any); } - if (input->dims().size() == 4 && input->dims()[1] == 3) { + if (input->dims().size() == 4 && input->dims()[1] <= 4) { chosen_memory_format = funcs::OneDNNMemoryFormat::nhwc; } const auto dst_md = funcs::OneDNNMemDesc( diff --git a/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc b/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc index 4b6498d07289ec..4dfc4a731bff2c 100644 --- a/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc +++ b/test/cpp/fluid/mkldnn/test_conv_mkldnn_nhwc.cc @@ -108,3 +108,92 @@ TEST(test_conv2d_output, int8) { op->Run(scope, cpu_place); } +TEST(test_conv2d_output, ic1) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + paddle::framework::OpDesc conv2d_op(nullptr); + conv2d_op.SetType("conv2d"); + conv2d_op.SetInput("Input", {"conv2d-X"}); + conv2d_op.SetInput("Filter", {"conv2d-Y"}); + conv2d_op.SetOutput("Output", {"conv2d-Out"}); + + AddVarToScope("conv2d-X", &scope, {1, 1, 224, 224}); + AddVarToScope("conv2d-Y", &scope, {64, 1, 7, 7}); + AddVarToScope("conv2d-Out", &scope, {1, 64, 218, 218}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + conv2d_op.SetAttr("strides", strides); + conv2d_op.SetAttr("paddings", paddings); + conv2d_op.SetAttr("dilations", dilations); + conv2d_op.SetAttr("groups", groups); + conv2d_op.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(conv2d_op); + + op->Run(scope, cpu_place); +} + +TEST(test_conv2d_output, ic2) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + paddle::framework::OpDesc conv2d_op(nullptr); + conv2d_op.SetType("conv2d"); + conv2d_op.SetInput("Input", {"conv2d-X"}); + conv2d_op.SetInput("Filter", {"conv2d-Y"}); + conv2d_op.SetOutput("Output", {"conv2d-Out"}); + + AddVarToScope("conv2d-X", &scope, {1, 2, 224, 224}); + AddVarToScope("conv2d-Y", &scope, {64, 2, 7, 7}); + AddVarToScope("conv2d-Out", &scope, {1, 64, 218, 218}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + conv2d_op.SetAttr("strides", strides); + conv2d_op.SetAttr("paddings", paddings); + conv2d_op.SetAttr("dilations", dilations); + conv2d_op.SetAttr("groups", groups); + conv2d_op.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(conv2d_op); + + op->Run(scope, cpu_place); +} + +TEST(test_conv2d_output, ic4) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace cpu_place; + + paddle::framework::OpDesc conv2d_op(nullptr); + conv2d_op.SetType("conv2d"); + conv2d_op.SetInput("Input", {"conv2d-X"}); + conv2d_op.SetInput("Filter", {"conv2d-Y"}); + conv2d_op.SetOutput("Output", {"conv2d-Out"}); + + AddVarToScope("conv2d-X", &scope, {1, 4, 224, 224}); + AddVarToScope("conv2d-Y", &scope, {64, 4, 7, 7}); + AddVarToScope("conv2d-Out", &scope, {1, 64, 218, 218}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + conv2d_op.SetAttr("strides", strides); + conv2d_op.SetAttr("paddings", paddings); + conv2d_op.SetAttr("dilations", dilations); + conv2d_op.SetAttr("groups", groups); + conv2d_op.SetAttr("use_mkldnn", true); + + auto op = paddle::framework::OpRegistry::CreateOp(conv2d_op); + + op->Run(scope, cpu_place); +}