Skip to content

Primreuse pooling #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 134 additions & 141 deletions tensorflow/core/kernels/mkl_avgpooling_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,22 +442,27 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {

void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
const Tensor& input_tensor =
MklGetInput(context, this->kInputTensorIndexInput);
const Tensor& input_tensor = MklGetInput(context,
this->kInputTensorIndexInput);
MklDnnShape dnn_shape_input;
GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
this->SanityCheckInput(context, input_tensor, dnn_shape_input);
if (!context->status().ok()) return;

MklDnnData<T> dnn_data_input(&cpu_engine);
MklDnnData<T> dnn_data_output(&cpu_engine);
MklDnnData<T> dnn_data_input(&cpu_engine_);

// initialize variables for the pooling op
MklPoolParameters pool_params;
// Get the input tensor and initialize the pooling parameters
this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
&dnn_data_input);
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params,
dnn_shape_input, input_tensor_shape);
// Get the input memory descriptor
memory::desc input_md = dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetMklLayout()
: memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
OP_REQUIRES_OK(context, context->status());

// Declare output tensor
Expand Down Expand Up @@ -487,45 +492,58 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
return;
}

// If input is in Mkl layout, then just get the memory format from it
// directly, instead of using input data_format to AvgPool.
if (dnn_shape_input.IsMklTensor()) {
dnn_data_output.SetUsrMem(
output_dims_mkl_order,
static_cast<memory::format>(
dnn_data_input.GetUsrMemDesc().data.format));
// Get src/filter/stride/padding information
memory::dims src_dims = dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
this->data_format_tf_);

memory::dims filter_dims = memory::dims({pool_params.window_rows,
pool_params.window_cols});
memory::dims strides = memory::dims(
{pool_params.row_stride, pool_params.col_stride});
memory::dims padding_left = memory::dims(
{static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)});
memory::dims padding_right = memory::dims(
{static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)});

// Get an average pooling primitive from the op pool
MklPoolingFwdPrimitive<T> *pooling_fwd = nullptr;
MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
strides, padding_left, padding_right,
algorithm::pooling_avg_exclude_padding);
pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);

// allocate output tensor
this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
output_dims_mkl_order, this->data_format_mkldnn_, &output_tensor);
CHECK_NOTNULL(output_tensor);

OP_REQUIRES_OK(context, context->status());

// check whether we need to reorder src
std::vector<primitive> net;
T* src_data = nullptr;
if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
dnn_data_input.SetUsrMem(input_md, &input_tensor);
auto src_target_primitive_desc = memory::primitive_desc({{src_dims},
MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()}, cpu_engine_);
dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc, &net);
src_data = static_cast<T*>(
dnn_data_input.GetOpMem().get_data_handle());
} else {
dnn_data_output.SetUsrMem(output_dims_mkl_order,
this->data_format_mkldnn_);
src_data = static_cast<T*>(const_cast<T*>(
input_tensor.flat<T>().data()));
}
stream(stream::kind::eager).submit(net).wait();

// describe the memory layout
dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);

// 3. create a pooling primitive descriptor
auto pool_desc = pooling_forward::desc(
prop_kind::forward, algorithm::pooling_avg_exclude_padding,
dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
memory::dims({pool_params.row_stride, pool_params.col_stride}),
memory::dims({pool_params.window_rows, pool_params.window_cols}),
memory::dims({static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)}),
memory::dims({static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)}),
TFPaddingToMklDnnPadding(this->padding_));
auto pool_prim_desc =
pooling_forward::primitive_desc(pool_desc, cpu_engine);

this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
CHECK_NOTNULL(output_tensor);

OP_REQUIRES_OK(context, context->status());
dnn_data_output.SetUsrMemDataHandle(output_tensor);
T* dst_data = static_cast<T*>(
const_cast<T*>(output_tensor->flat<T>().data()));

this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input,
&dnn_data_output);
// execute pooling
pooling_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
Expand All @@ -535,9 +553,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
errors::Aborted("Operation received an exception:", error_msg));
}
} // Compute
}; // MklAvgPoolingOp

//-----------------------------------------------------------------------------
private:
engine cpu_engine_ = engine(engine::cpu, 0);
}; // MklAvgPoolingOp

template <class Device, class T>
class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
Expand All @@ -547,125 +566,99 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {

void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape;
const Tensor& tensor_in_shape =
const Tensor& orig_input_tensor =
MklGetInput(context, kInputTensorIndexInputShape);
const Tensor& input_gradient_tensor =
const Tensor& grad_tensor =
MklGetInput(context, kInputTensorIndexInputGradient);

MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
GetMklShape(context, kInputTensorIndexInputShape,
&original_input_mkl_shape);
&orig_input_mkl_shape);
GetMklShape(context, kInputTensorIndexInputGradient,
&input_gradient_mkl_shape);

SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor,
original_input_mkl_shape, input_gradient_mkl_shape);
&grad_mkl_shape);
if (!context->status().ok()) return;

// Used to allocate output_diff_src/diff_src
// and create pool_fwd mdm desc
// 0. Input("orig_input_shape: int32") //NOT a T Tensor!
// 1. Input("grad: T")

MklDnnData<T> input_gradient_diff_dst(&cpu_engine);
MklDnnData<T> output_diff_src(&cpu_engine);
Tensor* output_tensor_diff_src = nullptr;
TensorShape original_input_shape;
MklDnnData<T> grad_dnn_data(&cpu_engine_);
MklPoolParameters pool_params;
memory::dims output_dims_mkl_order, original_input_dims_nchw;
// Configure the original input memory descriptor
memory::desc original_input_md = ConfigureOriginalInput(
context, tensor_in_shape, original_input_mkl_shape,
&original_input_dims_nchw, &pool_params, &original_input_shape);

// configure the original output memory descriptor
// by definition, the shape of the original output is the same
// as the shape of the gradient diff_dst
memory::desc original_output_md = this->ConfigureOriginalOutput(
pool_params, input_gradient_mkl_shape, output_dims_mkl_order);

memory::desc target_diff_dst_md = this->ConfigureInputGradient(
input_gradient_mkl_shape, input_gradient_tensor,
&input_gradient_diff_dst, original_output_md);
// The shape of the output diff src needs to be the same shape as the
// original input. But we will set its format to be same as the format of
// input gradient. We won't use format of original input since it will
// always be in Tensorflow layout (given that AvgPoolGrad gets shape of
// the input rather than actual input).
output_diff_src.SetUsrMem(
original_input_dims_nchw,
static_cast<memory::format>(target_diff_dst_md.data.format));

// Create the forward pooling primitive descriptor so we can reference it
// in the backward pooling primitive descriptor
auto pool_fwd_desc = pooling_forward::desc(
prop_kind::forward, algorithm::pooling_avg_exclude_padding,
original_input_md, original_output_md,
memory::dims({pool_params.row_stride, pool_params.col_stride}),
memory::dims({pool_params.window_rows, pool_params.window_cols}),
memory::dims({static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)}),
memory::dims({static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)}),
TFPaddingToMklDnnPadding(this->padding_));
auto pool_fwd_prim_desc =
pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);

auto pool_bkwd_desc = pooling_backward::desc(
algorithm::pooling_avg_exclude_padding,
output_diff_src.GetUsrMemDesc(), target_diff_dst_md,
memory::dims({pool_params.row_stride, pool_params.col_stride}),
memory::dims({pool_params.window_rows, pool_params.window_cols}),
memory::dims({static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)}),
memory::dims({static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)}),
TFPaddingToMklDnnPadding(this->padding_));
auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
this->AllocateOutputTensor(
context, pool_bkwd_prim_desc, original_input_dims_nchw,
this->data_format_mkldnn_, &output_tensor_diff_src);

output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src);

this->PrepareAndExecuteNet(
pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src,
memory::primitive_desc(target_diff_dst_md, cpu_engine));
auto shape_vec = orig_input_tensor.vec<int32>();
TensorShape orig_input_shape;
for (int i = 0; i < orig_input_tensor.NumElements(); i++) {
orig_input_shape.AddDim(shape_vec(i));
}
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims = memory::dims(
{pool_params.window_rows, pool_params.window_cols});
memory::dims strides = memory::dims(
{pool_params.row_stride, pool_params.col_stride});
memory::dims padding_left = memory::dims(
{static_cast<int>(pool_params.pad_top),
static_cast<int>(pool_params.pad_left)});
memory::dims padding_right = memory::dims(
{static_cast<int>(pool_params.pad_bottom),
static_cast<int>(pool_params.pad_right)});
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(orig_input_shape, this->data_format_tf_);
memory::dims diff_dst_dims = grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);

MklPoolingParams bwdParams(orig_input_dims_mkl_order,
output_dims_mkl_order, filter_dims, strides,
padding_left, padding_right, algorithm::pooling_avg_exclude_padding);
MklPoolingBwdPrimitive<T> *pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);

Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
orig_input_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
// get diff_dst memory::desc
memory::desc diff_dst_md = grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);
// Check whether we need to reorder diff_dst
T* diff_dst_data = nullptr;
std::vector<primitive> net;
if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
auto target_diff_dst = memory::primitive_desc({{diff_dst_dims},
MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()}, cpu_engine_);
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(target_diff_dst, &net);
diff_dst_data = static_cast<T*>(
grad_dnn_data.GetOpMem().get_data_handle());
} else {
diff_dst_data = static_cast<T*>(const_cast<T*>(
grad_tensor.flat<T>().data()));
}
stream(stream::kind::eager).submit(net).wait();
T* diff_src_data = static_cast<T*>(
const_cast<T*>(output_tensor->flat<T>().data()));

// execute pooling op
pooling_bwd->Execute(diff_dst_data, diff_src_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
error_msg));
}
} // Compute
}

private:
// 0. Input("orig_input_shape: int32")
// 1. Input("grad: T")
const int kInputTensorIndexInputShape = 0;
const int kInputTensorIndexInputGradient = 1;

memory::desc ConfigureOriginalInput(
OpKernelContext* context, const Tensor& tensor_original_input_shape,
const MklDnnShape& original_input_mkl_shape,
memory::dims* original_input_dims_mkl_order,
MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
CHECK_NOTNULL(original_input_dims_mkl_order);
CHECK_NOTNULL(pool_params);
CHECK_NOTNULL(input_tensor_shape);
// For AvgPoolGrad, we only get the size of the original input because
// The original data is irrelvant.
auto shape_vec = tensor_original_input_shape.vec<int32>();
for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) {
input_tensor_shape->AddDim(shape_vec(i));
}

return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
context, tensor_original_input_shape, original_input_mkl_shape,
original_input_dims_mkl_order, pool_params, *input_tensor_shape);
}
engine cpu_engine_ = engine(engine::cpu, 0);

void SanityCheckInputs(OpKernelContext* context,
const Tensor& tensor_in_shape,
Expand Down
Loading