Add bf16 data type support to oneDNN bilinear_interp kernel (#46770)
* Enable bf16 in oneDNN bilinear_interp kernel

* Fix bilinear_interp_v2 not enabled in models

* Remove unnecessary checks
piotrekobi authored Nov 16, 2022
1 parent e23dfed commit 8e6315e
Showing 5 changed files with 50 additions and 20 deletions.
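For context, a minimal sketch (not part of this commit) of how the new bf16 path is typically reached: CPU inference with oneDNN and bfloat16 enabled, so the bfloat16 placement pass can mark supported operators such as bilinear_interp_v2. The model paths and input shape below are hypothetical placeholders; the paddle.inference Python API is assumed.

import numpy as np
import paddle.inference as paddle_infer

# Hypothetical model files; any saved model containing a bilinear_interp_v2 op would do.
config = paddle_infer.Config("./model.pdmodel", "./model.pdiparams")
config.disable_gpu()
config.enable_mkldnn()           # run oneDNN (MKL-DNN) kernels on CPU
config.enable_mkldnn_bfloat16()  # let the bfloat16 placement pass mark supported ops

predictor = paddle_infer.create_predictor(config)
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)
input_handle.copy_from_cpu(np.random.rand(1, 3, 224, 224).astype("float32"))
predictor.run()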
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2789,7 +2789,8 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"cast",
std::unordered_set<std::string>({"bilinear_interp_v2",
"cast",
"clip",
"concat",
"conv2d",
@@ -56,9 +56,7 @@ int CPUBfloat16PlacementPass::SetMkldnnDataType(ir::Graph* graph) const {
// Only float input can be converted to bfloat16
if (op_in->Var()->GetDataType() != proto::VarType::FP32) return;

- if ((op->Op()->HasAttr("mkldnn_data_type") ||
-      op->Op()->HasProtoAttr("mkldnn_data_type")) &&
-     !platform::HasOpINT8DataType(op->Op())) {
+ if (platform::HasOpINT8DataType(op->Op()) == false) {
VLOG(4) << "--- marked " << op->Op()->Type()
<< " operator to bfloat16 ";
op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -329,6 +329,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass",
"interpolate_mkldnn_pass",
// TODO(baoachun): Need to support 5-dimensional input.
// "conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
8 changes: 6 additions & 2 deletions paddle/phi/kernels/onednn/interpolate_kernel.cc
@@ -227,8 +227,12 @@ void NearestInterpKernel(
}
} // namespace phi

- PD_REGISTER_KERNEL(
-     bilinear_interp, OneDNN, ONEDNN, phi::BilinearInterpKernel, float) {}
+ PD_REGISTER_KERNEL(bilinear_interp,
+                    OneDNN,
+                    ONEDNN,
+                    phi::BilinearInterpKernel,
+                    float,
+                    phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(nearest_interp,
OneDNN,
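The extra phi::dtype::bfloat16 entry above registers a second instantiation of phi::BilinearInterpKernel for the oneDNN backend, so bilinear_interp can now be dispatched for bfloat16 tensors as well as float32.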
@@ -15,11 +15,11 @@
import unittest
import numpy as np
import math
- from paddle.fluid.tests.unittests.op_test import OpTest
+ from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci


- def bilinear_interp_mkldnn_np(
+ def bilinear_interp_onednn_np(
input, out_h, out_w, out_size=None, actual_shape=None, data_layout='NCHW'
):
"""bilinear interpolation implement in shape [N, C, H, W]"""
@@ -65,17 +65,21 @@ def bilinear_interp_mkldnn_np(


@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
- class TestBilinearInterpMKLDNNOp(OpTest):
+ class TestBilinearInterpOneDNNOp(OpTest):
def init_test_case(self):
pass

+ def init_data_type(self):
+     pass

def setUp(self):
self.op_type = "bilinear_interp_v2"
self.interp_method = 'bilinear'
self._cpu_only = True
- self.use_mkldnn = True
+ self.use_onednn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
+ self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
@@ -84,8 +88,12 @@ def setUp(self):
self.actual_shape = None

self.init_test_case()
+ self.init_data_type()

- input_np = np.random.random(self.input_shape).astype("float32")
+ input_np = np.random.random(self.input_shape).astype(self.dtype)
+ if self.dtype == np.uint16:
+     input_np = convert_float_to_uint16(input_np)

if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
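For context (not part of the diff): the OpTest framework stores bfloat16 tensors as np.uint16 arrays, because NumPy has no native bfloat16 dtype. The sketch below shows roughly what convert_float_to_uint16 does, keeping only the upper 16 bits of each float32 bit pattern; it is a simplified stand-in and ignores any rounding the real helper may perform.

import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 is the upper half of an IEEE-754 float32 bit pattern, so
    # dropping the low 16 bits yields the bf16 value packed into uint16.
    # (Simplified, truncation-only stand-in for convert_float_to_uint16.)
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)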
@@ -114,7 +122,7 @@ def setUp(self):
out_h = self.out_h
out_w = self.out_w

- output_np = bilinear_interp_mkldnn_np(
+ output_np = bilinear_interp_onednn_np(
input_np,
out_h,
out_w,
@@ -137,15 +145,15 @@ def setUp(self):
'out_w': self.out_w,
'scale': self.scale,
'data_layout': self.data_layout,
- 'use_mkldnn': self.use_mkldnn,
+ 'use_mkldnn': self.use_onednn,
}
self.outputs = {'Out': output_np}

def test_check_output(self):
self.check_output(check_dygraph=False)


- class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp):
+ class TestBilinearInterpOpOneDNNNHWC(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [3, 2, 32, 16]
self.out_h = 27
@@ -154,22 +162,22 @@ def init_test_case(self):
self.data_layout = 'NHWC'


- class TestBilinearNeighborInterpMKLDNNCase2(TestBilinearInterpMKLDNNOp):
+ class TestBilinearNeighborInterpOneDNNCase2(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12


- class TestBilinearNeighborInterpCase3(TestBilinearInterpMKLDNNOp):
+ class TestBilinearNeighborInterpOneDNNCase3(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 128
self.scale = [0.1, 0.05]


- class TestBilinearNeighborInterpCase4(TestBilinearInterpMKLDNNOp):
+ class TestBilinearNeighborInterpOneDNNCase4(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
@@ -178,15 +186,15 @@ def init_test_case(self):
self.out_size = np.array([65, 129]).astype("int32")


- class TestBilinearNeighborInterpCase5(TestBilinearInterpMKLDNNOp):
+ class TestBilinearNeighborInterpOneDNNCase5(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([13, 13]).astype("int32")


- class TestBilinearNeighborInterpCase6(TestBilinearInterpMKLDNNOp):
+ class TestBilinearNeighborInterpOneDNNCase6(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
@@ -195,7 +203,7 @@ def init_test_case(self):
self.out_size = np.array([65, 129]).astype("int32")


- class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp):
+ class TestBilinearNeighborInterpOneDNNSame(TestBilinearInterpOneDNNOp):
def init_test_case(self):
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
@@ -204,6 +212,24 @@ def init_test_case(self):
self.out_size = np.array([65, 129]).astype("int32")


+ def create_test_class(parent):
+     class TestBf16Case(parent):
+         def init_data_type(self):
+             self.dtype = np.uint16
+
+     TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
+     globals()[TestBf16Case.__name__] = TestBf16Case
+
+
+ create_test_class(TestBilinearInterpOneDNNOp)
+ create_test_class(TestBilinearInterpOpOneDNNNHWC)
+ create_test_class(TestBilinearNeighborInterpOneDNNCase2)
+ create_test_class(TestBilinearNeighborInterpOneDNNCase3)
+ create_test_class(TestBilinearNeighborInterpOneDNNCase4)
+ create_test_class(TestBilinearNeighborInterpOneDNNCase5)
+ create_test_class(TestBilinearNeighborInterpOneDNNCase6)
+ create_test_class(TestBilinearNeighborInterpOneDNNSame)

if __name__ == "__main__":
from paddle import enable_static
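Each create_test_class call above registers a bf16 variant of the given test case in the module namespace; for example, the first call effectively defines:

class TestBilinearInterpOneDNNOp_BF16(TestBilinearInterpOneDNNOp):
    def init_data_type(self):
        # bf16 inputs are packed into uint16 (see convert_float_to_uint16)
        self.dtype = np.uint16

The generated class is then picked up by unittest discovery like any other test in the module.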
