@@ -296,7 +296,9 @@ class MklInputConversionOp : public OpKernel {
296
296
// implementation.
297
297
TensorShape tf_shape0 = input_shape_0.GetTfShape ();
298
298
TensorShape tf_shape1 = input_shape_1.GetTfShape ();
299
- if (tf_shape0 == tf_shape1) {
299
+ TensorShape tensor_shape0 = input_tensor_0.shape ();
300
+ TensorShape tensor_shape1 = input_tensor_1.shape ();
301
+ if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
300
302
auto input0_md = input_shape_0.GetMklLayout ();
301
303
auto input1_md = input_shape_1.GetMklLayout ();
302
304
@@ -350,7 +352,8 @@ class MklInputConversionOp : public OpKernel {
350
352
}
351
353
352
354
// Sanity check
353
- bool mkl_shapes_are_same = input_shape_0 == input_shape_1;
355
+ bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
356
+ (tensor_shape0 == tensor_shape1));
354
357
if (mkl_shapes_are_same) {
355
358
CHECK (false ) << " MklInputConversionOp: Unexpected: TF shapes are "
356
359
" different but MKL shapes are same" ;
@@ -403,7 +406,8 @@ class MklInputConversionOp : public OpKernel {
403
406
}
404
407
405
408
// Broadcast is needed if the shapes are not the same
406
- if (mkl_shape->GetTfShape ().num_elements () == tf_tensor->shape ().num_elements () ) {
409
+ if (mkl_shape->GetTfShape ().num_elements ()
410
+ == tf_tensor->shape ().num_elements () ) {
407
411
// Both shapes are same, convert the TF input to MKL
408
412
VLOG (1 ) << " MklInputConversionOp: No broadcast needed." ;
409
413
VLOG (1 ) << " MklInputConversionOp: Converting input " << tf_tensor_index
@@ -437,16 +441,17 @@ class MklInputConversionOp : public OpKernel {
437
441
bool reordered = tf_input.CheckReorderToOpMem (
438
442
memory::primitive_desc (output_mkl_md, cpu_engine),
439
443
tensor_out, &net);
440
- if (!reordered) {
444
+
445
+ if (!reordered) {
441
446
// This is the case that the TF tensor has the same shape and format of
442
447
// mkl tensor. However, tf_tensor can not be simply forwarded to the
443
448
// output tensor since mkl data tensor is always one dimensional tensor.
444
449
// Tensor::CopyFrom shares the buffer of the other tensor while set its
445
450
// shape to the other tensor.
446
451
CHECK (tensor_out->CopyFrom (*tf_tensor, tensor_out->shape ()));
447
- }
448
- else
452
+ } else {
449
453
stream (stream::kind::eager).submit (net).wait ();
454
+ }
450
455
451
456
// -- The tensor in MKL format passes through --
452
457
ForwardMklTensorInToOut (context, mkl_tensor_index, mkl_tensor_index);
0 commit comments