Skip to content

Commit d3b4075

Browse files
committed
Merge branch 'primreuse_relu' into latest_optimizations
2 parents 200a599 + ffc12e1 commit d3b4075

File tree

2 files changed

+510
-127
lines changed

2 files changed

+510
-127
lines changed

tensorflow/core/kernels/mkl_input_conversion_op.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ class MklInputConversionOp : public OpKernel {
296296
// implementation.
297297
TensorShape tf_shape0 = input_shape_0.GetTfShape();
298298
TensorShape tf_shape1 = input_shape_1.GetTfShape();
299-
if (tf_shape0 == tf_shape1) {
299+
TensorShape tensor_shape0 = input_tensor_0.shape();
300+
TensorShape tensor_shape1 = input_tensor_1.shape();
301+
if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
300302
auto input0_md = input_shape_0.GetMklLayout();
301303
auto input1_md = input_shape_1.GetMklLayout();
302304

@@ -350,7 +352,8 @@ class MklInputConversionOp : public OpKernel {
350352
}
351353

352354
// Sanity check
353-
bool mkl_shapes_are_same = input_shape_0 == input_shape_1;
355+
bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
356+
(tensor_shape0 == tensor_shape1));
354357
if (mkl_shapes_are_same) {
355358
CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
356359
"different but MKL shapes are same";
@@ -403,7 +406,8 @@ class MklInputConversionOp : public OpKernel {
403406
}
404407

405408
// Broadcast is needed if the shapes are not the same
406-
if (mkl_shape->GetTfShape().num_elements() == tf_tensor->shape().num_elements() ) {
409+
if (mkl_shape->GetTfShape().num_elements()
410+
== tf_tensor->shape().num_elements() ) {
407411
// Both shapes are same, convert the TF input to MKL
408412
VLOG(1) << "MklInputConversionOp: No broadcast needed.";
409413
VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
@@ -437,16 +441,17 @@ class MklInputConversionOp : public OpKernel {
437441
bool reordered = tf_input.CheckReorderToOpMem(
438442
memory::primitive_desc(output_mkl_md, cpu_engine),
439443
tensor_out, &net);
440-
if(!reordered) {
444+
445+
if (!reordered) {
441446
// This is the case that the TF tensor has the same shape and format of
442447
// mkl tensor. However, tf_tensor can not be simply forwarded to the
443448
// output tensor since mkl data tensor is always one dimensional tensor.
444449
// Tensor::CopyFrom shares the buffer of the other tensor while set its
445450
// shape to the other tensor.
446451
CHECK(tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()));
447-
}
448-
else
452+
} else {
449453
stream(stream::kind::eager).submit(net).wait();
454+
}
450455

451456
// -- The tensor in MKL format passes through --
452457
ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);

0 commit comments

Comments
 (0)