18 changes: 15 additions & 3 deletions paddle/fluid/operators/concat_op.cc
@@ -169,9 +169,21 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace());
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    // Checking whether the attribute "use_mkldnn" exists is necessary
+    // because test_reverse_op calls the concat_grad kernel without
+    // setting "use_mkldnn" to any value.
+    if (ctx.HasAttr("use_mkldnn") &&
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
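For context, here is a minimal sketch of how this dispatch path can be exercised from Python. It is illustrative only and not part of the diff: it assumes a PADDLE_WITH_MKLDNN build and the standard static-graph API, and whether the oneDNN kernel is actually selected still depends on FLAGS_use_mkldnn being set so that the executor marks supported ops with the "use_mkldnn" attribute.

    # Illustrative sketch (assumes a oneDNN build; run with FLAGS_use_mkldnn=1).
    import numpy as np
    import paddle

    paddle.enable_static()
    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x0 = paddle.static.data('x0', [2, 3], 'float32')
        x1 = paddle.static.data('x1', [2, 4], 'float32')
        x0.stop_gradient = False
        x1.stop_gradient = False
        out = paddle.concat([x0, x1], axis=1)
        # Appending backward creates a concat_grad op, which the change
        # above can now dispatch to the oneDNN kernel.
        dx = paddle.static.gradients([paddle.mean(out)], [x0, x1])

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup)
    d0, d1 = exe.run(main,
                     feed={'x0': np.ones([2, 3], np.float32),
                           'x1': np.ones([2, 4], np.float32)},
                     fetch_list=dx)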
71 changes: 71 additions & 0 deletions paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
@@ -23,6 +23,7 @@ namespace operators {
 
 using framework::DataLayout;
 using framework::Tensor;
+using framework::LoDTensor;
 using mkldnn::memory;
 using mkldnn::primitive;
 using mkldnn::concat;
@@ -149,6 +150,72 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     output->set_format(platform::GetMKLDNNFormat(*dst_mem));
   }
 };
+
+template <typename T>
+class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    const auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& onednn_engine = dev_ctx.GetEngine();
+
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+
+    auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
+
+    const auto x = ctx.MultiInput<LoDTensor>("X");
+    const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto dx = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
+
+    for (size_t i = 0; i < dx.size(); ++i) {
+      if (dx[i] != nullptr) {
+        dx[i]->set_lod(x[i]->lod());
+      }
+    }
+
+    int axis = ctx.Attr<int>("axis");
+    if (ctx.HasInput("AxisTensor")) {
+      auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
+      axis = GetDataFromTensor<int>(axis_tensor)[0];
+    }
+
+    auto dout_vec_dims = framework::vectorize(dout->dims());
+
+    axis = ComputeAxis(axis, dout_vec_dims.size());
+
+    std::vector<int64_t> offset(dout_vec_dims.size(), 0);
+
+    mkldnn::memory::data_type dout_type =
+        framework::ToMKLDNNDataType(dout->type());
+    platform::ReorderMKLDNNHandler reorder_handler(dout_vec_dims, dout->type(),
+                                                   dout_type, onednn_engine);
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        dout->format(), platform::to_void_cast(dout->data<T>()));
+
+    for (size_t i = 0; i < dx.size(); ++i) {
+      if (out_var_names[i] != framework::kEmptyVarName &&
+          dx[i]->numel() != 0UL) {
+        auto dx_vec_dims = framework::vectorize(dx[i]->dims());
+        auto slice_mem_p = reorder_handler.AcquireSubmemory(
+            dx_vec_dims, offset, reorder_src_memory_p);
+
+        auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
+            dx[i], dx_vec_dims, dout->format(), ctx.GetPlace());
+        auto reorder_p =
+            reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
+
+        reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
+
+        offset[axis] += dx[i]->dims()[axis];
+
+        dx[i]->set_layout(framework::DataLayout::kMKLDNN);
+        dx[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
+      }
+    }
+    astream.wait();
+  }
+};
 
 }  // namespace operators
 }  // namespace paddle
@@ -159,3 +226,7 @@ REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::ConcatMKLDNNOpKernel<paddle::platform::bfloat16>,
                    ops::ConcatMKLDNNOpKernel<int8_t>,
                    ops::ConcatMKLDNNOpKernel<uint8_t>);
+
+REGISTER_OP_KERNEL(concat_grad, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::ConcatGradMKLDNNOpKernel<float>,
+                   ops::ConcatGradMKLDNNOpKernel<paddle::platform::bfloat16>);
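The gradient kernel above implements concat's backward pass as a sequence of oneDNN reorders: each dx[i] is a sub-memory view of dout taken at a running offset along the concat axis and copied out into its own buffer. A NumPy sketch of the same computation, illustrative only and not part of the PR:

    import numpy as np

    def concat_grad(dout, in_shapes, axis):
        """Slice dout back into per-input gradients along `axis`."""
        dxs, offset = [], 0
        for shape in in_shapes:
            width = shape[axis]
            index = [slice(None)] * dout.ndim
            index[axis] = slice(offset, offset + width)
            dxs.append(dout[tuple(index)])
            offset += width
        return dxs

    dout = np.arange(24.0).reshape(2, 12)
    dx0, dx1 = concat_grad(dout, [(2, 5), (2, 7)], axis=1)
    assert dx0.shape == (2, 5) and dx1.shape == (2, 7)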
27 changes: 21 additions & 6 deletions python/paddle/fluid/tests/unittests/mkldnn/test_concat_bf16_mkldnn_op.py
@@ -40,13 +40,28 @@ def setUp(self):
             'mkldnn_data_type': self.mkldnn_data_type
         }
 
+        self.sections = [self.x0.shape[self.axis]] * 2
+        self.sections[1] += self.x1.shape[self.axis]
+
         self.output = np.concatenate(
             (self.x0, self.x1, self.x2), axis=self.axis).astype(np.uint16)
         self.outputs = {'Out': self.output}
 
+    def calculate_grads(self):
+        self.dout = self.outputs['Out']
+        self.dxs = np.split(self.dout, self.sections, self.axis)
+
     def test_check_output(self):
         self.check_output_with_place(core.CPUPlace())
 
+    def test_check_grad(self):
+        self.calculate_grads()
+        self.check_grad_with_place(
+            core.CPUPlace(), ["x0", "x1", "x2"],
+            "Out",
+            user_defined_grads=[self.dxs[0], self.dxs[1], self.dxs[2]],
+            user_defined_grad_outputs=[self.dout])
+
 # --------------------test concat bf16 in with axis 0--------------------
 
     def init_test_data(self):
@@ -61,9 +76,9 @@ def init_axis(self):
         self.axis = 0
 
     def init_shape(self):
-        self.x0_shape = [2, 2, 1, 2]
-        self.x1_shape = [1, 2, 1, 2]
-        self.x2_shape = [3, 2, 1, 2]
+        self.x0_shape = [6, 2, 4, 3]
+        self.x1_shape = [7, 2, 4, 3]
+        self.x2_shape = [8, 2, 4, 3]
 
 
 # --------------------test concat bf16 in with axis 1--------------------
@@ -74,9 +89,9 @@ def init_axis(self):
         self.axis = 1
 
     def init_shape(self):
-        self.x0_shape = [1, 1, 5, 5]
-        self.x1_shape = [1, 2, 5, 5]
-        self.x2_shape = [1, 3, 5, 5]
+        self.x0_shape = [1, 4, 5, 5]
+        self.x1_shape = [1, 8, 5, 5]
+        self.x2_shape = [1, 6, 5, 5]
 
 
 # --------------------test concat bf16 in with axis 2--------------------
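Note that np.split takes split indices rather than section sizes, which is why `sections` in setUp is built as a running sum: for inputs of widths s0, s1, s2 along the concat axis, the indices [s0, s0 + s1] cut dout into three pieces. A small worked example (illustrative, using the new axis-0 shapes):

    import numpy as np

    s0, s1, s2 = 6, 7, 8          # axis-0 sizes of x0, x1, x2 above
    sections = [s0, s0 + s1]      # split indices: [6, 13]
    dout = np.zeros([s0 + s1 + s2, 2, 4, 3])
    dx0, dx1, dx2 = np.split(dout, sections, axis=0)
    assert [d.shape[0] for d in (dx0, dx1, dx2)] == [6, 7, 8]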
114 changes: 63 additions & 51 deletions python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
@@ -15,78 +15,90 @@
 from __future__ import print_function
 
 import unittest
-from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3, TestConcatOp4
 import numpy as np
 import struct
-
-
-class TestMKLDNNConcatOp(TestConcatOp):
-    def setUp(self):
-        super(TestMKLDNNConcatOp, self).setUp()
-        self.attrs["use_mkldnn"] = True
-        self._cpu_only = True
-
-    def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
-
-    def test_check_grad(self):
-        pass
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
+from paddle import enable_static
 
 
-class TestMKLDNNConcatOp2(TestConcatOp2):
+class TestConcatAxis0OneDNNOp(OpTest):
     def setUp(self):
-        super(TestMKLDNNConcatOp2, self).setUp()
-        self.attrs["use_mkldnn"] = True
-        self._cpu_only = True
+        self.op_type = "concat"
+        self.mkldnn_data_type = "float32"
+        self.init_axis()
+        self.init_shape()
+        self.init_test_data()
+        self.configure_datatype()
+        self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
+        self.attrs = {
+            'axis': self.axis,
+            'use_mkldnn': True,
+            'mkldnn_data_type': self.mkldnn_data_type
+        }
+
+        self.output = np.concatenate(
+            (self.x0, self.x1, self.x2), axis=self.axis).astype(self.dtype)
+
+        self.outputs = {'Out': self.output}
+
+    def configure_datatype(self):
+        self.mkldnn_data_type = "float32"
+        self.dtype = np.float32
 
     def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
+        self.check_output_with_place(core.CPUPlace())
 
     def test_check_grad(self):
-        pass
+        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')
 
-    def init_kernel_type(self):
-        self.use_mkldnn = True
+    def init_test_data(self):
+        self.x0 = np.random.random(self.x0_shape).astype(np.float32)
+        self.x1 = np.random.random(self.x1_shape).astype(np.float32)
+        self.x2 = np.random.random(self.x2_shape).astype(np.float32)
 
+    def init_axis(self):
+        self.axis = 0
 
-class TestMKLDNNConcatOp3(TestConcatOp3):
-    def setUp(self):
-        super(TestMKLDNNConcatOp3, self).setUp()
-        self.attrs["use_mkldnn"] = True
-        self._cpu_only = True
+    def init_shape(self):
+        self.x0_shape = [2, 2, 1, 50]
+        self.x1_shape = [1, 2, 1, 50]
+        self.x2_shape = [3, 2, 1, 50]
 
-    def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
 
-    def test_check_grad(self):
-        pass
+class TestConcatAxis1OneDNNOp(TestConcatAxis0OneDNNOp):
+    def init_axis(self):
+        self.axis = 1
 
-    def init_kernel_type(self):
-        self.use_mkldnn = True
+    def init_shape(self):
+        self.x0_shape = [1, 1, 5, 50]
+        self.x1_shape = [1, 2, 5, 50]
+        self.x2_shape = [1, 3, 5, 50]
 
 
-class TestMKLDNNConcatOp4(TestConcatOp4):
-    def setUp(self):
-        super(TestMKLDNNConcatOp4, self).setUp()
-        self.attrs["use_mkldnn"] = True
-        self._cpu_only = True
+class TestConcatAxis2OneDNNOp(TestConcatAxis0OneDNNOp):
+    def init_axis(self):
+        self.axis = 2
 
-    def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode
-        self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
+    def init_shape(self):
+        self.x0_shape = [2, 3, 4, 50]
+        self.x1_shape = [2, 3, 5, 50]
+        self.x2_shape = [2, 3, 6, 50]
 
-    def test_check_grad(self):
-        pass
 
-    def init_kernel_type(self):
-        self.use_mkldnn = True
+class TestConcatAxis3OneDNNOp(TestConcatAxis0OneDNNOp):
+    def init_axis(self):
+        self.axis = 3
+
+    def init_shape(self):
+        self.x0_shape = [5, 3, 5, 5]
+        self.x1_shape = [5, 3, 5, 6]
+        self.x2_shape = [5, 3, 5, 7]
 
 
 if __name__ == '__main__':
-    from paddle import enable_static
     enable_static()
     unittest.main()
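Unlike the bf16 test, which supplies user-defined gradients, these fp32 tests rely on OpTest's built-in check_grad, which compares the gradient produced by the concat_grad kernel against a numerically estimated one. A simplified sketch of that idea (not OpTest's actual implementation):

    import numpy as np

    def numeric_grad(f, x, delta=5e-3):
        """Central-difference gradient of sum(f(x)) w.r.t. x."""
        grad = np.zeros_like(x)
        it = np.nditer(x, flags=['multi_index'])
        while not it.finished:
            idx = it.multi_index
            orig = x[idx]
            x[idx] = orig + delta
            hi = f(x).sum()
            x[idx] = orig - delta
            lo = f(x).sum()
            x[idx] = orig
            grad[idx] = (hi - lo) / (2 * delta)
            it.iternext()
        return grad

    # For concat, the gradient of sum(out) w.r.t. x0 is all ones:
    x0, x1 = np.random.rand(2, 3), np.random.rand(2, 4)
    g = numeric_grad(lambda v: np.concatenate([v, x1], axis=1), x0)
    assert np.allclose(g, np.ones_like(x0))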
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_concat_op.py
@@ -16,7 +16,7 @@
 
 import unittest
 import numpy as np
-from op_test import OpTest, skip_check_grad_ci
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard, core
 import paddle