From 6739ada98125d85824c47c4ff367539592c3f7fa Mon Sep 17 00:00:00 2001
From: Renjie Liu <renjieliu@google.com>
Date: Fri, 28 Feb 2020 00:46:34 -0800
Subject: [PATCH] Avoid updating keras lstm function signature

PiperOrigin-RevId: 297787450
Change-Id: Id6737125bcebd4c204f60d91c3c005fd5c38f326
---
 .../tests/prepare-composite-functions-tf.mlir | 48 ++++++++++++-------
 .../compiler/mlir/lite/utils/lstm_utils.cc    | 33 +++++++++++--
 2 files changed, 60 insertions(+), 21 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
index 0dcce6bf4e8280..8a2d702578c38d 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
@@ -154,18 +154,18 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
 // -----
 
 module {
-func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
+func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
   %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
   %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
   %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
-  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
+  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
   %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
   %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+  return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 }
 
-// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
+// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
 // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
 // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
 // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@@ -181,8 +181,12 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
 // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK: [[VAL_19:%.*]] = constant unit
 // CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
-// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
-// CHECK: return [[VAL_21:%.*]] : tensor<8x?x10xf32>
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
+// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK: return [[VAL_21]], [[VAL_25:%.*]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 // CHECK: }
 }
 
@@ -200,7 +204,7 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
   return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 }
 
-// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
+// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
 // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
 // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
 // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@@ -221,25 +225,29 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
 // CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
 // CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
-// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
+// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK: return [[VAL_26]], [[VAL_24]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 // CHECK: }
 }
 
 // -----
 
 module {
-func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
   %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
   %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
   %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
   %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
-  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
+  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
   %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
   %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
-  return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
+  return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 }
 
-// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
+// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
 // CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32>
 // CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
 // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@@ -257,8 +265,12 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>,
 // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
 // CHECK: [[VAL_21:%.*]] = constant unit
 // CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
-// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
-// CHECK: return [[VAL_23:%.*]] : tensor<8x?x10xf32>
+// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
+// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK: return [[VAL_23]], [[VAL_27:%.*]], [[VAL_24]], [[VAL_25]], [[VAL_26]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 // CHECK: }
 }
 
@@ -277,7 +289,7 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3
   return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 }
 
-// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
+// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
 // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
 // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
 // CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32>
@@ -300,7 +312,11 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3
 // CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
 // CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
 // CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
-// CHECK: return [[VAL_26]] : tensor<8x8x10xf32>
+// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_30:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
+// CHECK: [[VAL_31:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
+// CHECK: return [[VAL_28]], [[VAL_26]], [[VAL_29]], [[VAL_30]], [[VAL_31]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
 // CHECK: }
 }
 
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 400d504b8f1072..54828d79d01972 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -57,6 +57,14 @@ Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
   return builder->create<ConstantOp>(location, type, attr);
 }
 
+Value CreatTfF32ConstOp(OpBuilder* builder, ArrayRef<int64_t> shape, float val,
+                        mlir::Location location) {
+  auto type = RankedTensorType::get(shape, builder->getF32Type());
+  auto ele_type = RankedTensorType::get({1}, builder->getF32Type());
+  auto attr = DenseElementsAttr::get(ele_type, val);
+  return builder->create<TF::ConstOp>(location, type, attr);
+}
+
 Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                           ArrayRef<int64_t> values, mlir::Location location) {
   auto type = RankedTensorType::get(static_cast<int32_t>(shape.size()),
@@ -601,6 +609,9 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
   Value recurrent_kernel = func_op.getArgument(4);
   Value bias = func_op.getArgument(5);
 
+  // The func op should have 5 outputs.
+  if (func_op.getNumResults() != 5) return failure();
+
   // TFL lstm only supports time-majored inputs, so if it's not time-majored,
   // we will transpose the inputs and outputs.
   auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
@@ -674,11 +685,8 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
           &bias_array)))
     return failure();
 
-  // Update the function signature:
-  UpdateFuncSignature(batch, time, n_output, &func_op);
-
   // Build the lstm op.
-  SmallVector<int64_t, 3> output_shape = {batch, time, n_output};
+  SmallVector<int64_t, 3> output_shape = {time, batch, n_output};
   auto result_type = mlir::RankedTensorType::get(
       output_shape,
       input.getType().cast<RankedTensorType>().getElementType());
@@ -718,7 +726,22 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
     final_output = Transpose(builder, final_output, perm, result_type,
                              func_op.getLoc());
   }
-  builder->create<mlir::ReturnOp>(func_op.getLoc(), final_output);
+
+  SmallVector<Value, 5> outputs;
+
+  for (int i = 0; i < 5; ++i) {
+    if (i == 1) {
+      // only this one is the real output.
+      outputs.push_back(final_output);
+    } else {
+      auto result_type =
+          func_op.getCallableResults()[i].dyn_cast<RankedTensorType>();
+      outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f,
+                                          func_op.getLoc()));
+    }
+  }
+
+  builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
   return success();
 }
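Note on the approach: instead of rewriting the composite function's type (the removed UpdateFuncSignature call), the pass now keeps the five-result Keras LSTM signature intact, places the fused tfl.lstm result in slot 1, and fills the remaining four result slots with zero-valued tf.Const stubs whose types are copied from the declared results via func_op.getCallableResults(). The sketch below reassembles what the rewritten function body looks like, derived from the CHECK expectations above; the function and value names are invented for illustration, and the shapes are the ones from the time-major test (batch 8, 10 hidden units), so treat it as a shape sketch rather than verifier-checked output.

  func @stubbed_lstm_outputs_sketch(%lstm_out: tensor<?x8x10xf32>)
      -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) {
    // Stubs for results 0, 2, 3 and 4: CreatTfF32ConstOp always builds the
    // splat attribute with type tensor<1xf32>, while each op's result type is
    // copied from the original signature -- hence the attribute/result type
    // mismatch that is visible in the CHECK lines above.
    %s0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
    %s2 = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
    %s3 = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
    %s4 = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
    // Result 1 carries the only live value: the tfl.lstm hidden-state sequence.
    return %s0, %lstm_out, %s2, %s3, %s4
        : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
  }

Keeping the signature stable means callers of the outlined function need no rewriting; as the patch's own comment ("only this one is the real output") indicates, this relies on downstream consumers using only result 1, which is why constant zeros are acceptable placeholders for the state and metadata outputs.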