Fuse Keras LSTM conditionally: due to the difference between the Keras LSTM and the TFL fused LSTM, we only fuse the Keras LSTM if the full sequence output is consumed.

PiperOrigin-RevId: 2979862
Change-Id: I65e3eb71617249471122d0efea9f6a05f3234302
renjie-liu authored and tensorflower-gardener committed Feb 29, 2020
1 parent 59c7725 commit 03d5fe9
Showing 3 changed files with 144 additions and 19 deletions.
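In essence, the change rewrites a Keras LSTM function only when its callers consume nothing but the full-sequence output. A minimal, self-contained C++ sketch of that rule (the function name IsFusableKerasLstmCall and the plain use counts are hypothetical stand-ins; the real check, CheckOutputConsumer in the pass source below, walks actual MLIR uses):

#include <set>
#include <vector>

// A Keras LSTM call has five results; it is fusable only when result #1
// (the full-sequence output) is the sole result that has any consumer.
bool IsFusableKerasLstmCall(const std::vector<int>& result_use_counts) {
  // Indices of results that must be consumed; every other result must be dead.
  const std::set<int> expected_consumed = {1};
  if (result_use_counts.size() != 5) return false;
  for (int i = 0; i < 5; ++i) {
    const bool consumed = result_use_counts[i] > 0;
    const bool expected = expected_consumed.count(i) > 0;
    if (consumed != expected) return false;
  }
  return true;
}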
@@ -321,3 +321,82 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3

}

// -----

module {
func @inference_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
%2 = "tf.Add"(%0, %1#1) : (tensor<f32>, tensor<?x8x10xf32>) -> tensor<?x8x10xf32>
return
}

func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}

// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_19:%.*]] = constant unit
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_21]], [[VAL_25:%.*]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }

}

// -----

module {
func @inference_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_cannot_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
%2 = "tf.Add"(%0, %1#2) : (tensor<f32>, tensor<?x10xf32>) -> tensor<?x10xf32>
return
}

func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}

// CHECK: func @inference_standard_lstm_time_major_cannot_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = "tf.BatchMatMulV2"([[VAL_0]], [[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
// CHECK: [[VAL_7:%.*]] = "tf.Add"([[VAL_6]], [[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
// CHECK: [[VAL_8:%.*]] = "tf.BatchMatMulV2"([[VAL_7]], [[VAL_4]]) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
// CHECK: [[VAL_9:%.*]] = "tf.Add"([[VAL_8]], [[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK: [[VAL_10:%.*]] = "tf.Add"([[VAL_8]], [[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK: [[VAL_11:%.*]] = "tf.Add"([[VAL_1]], [[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}
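Taken together, the two new test cases cover both sides of the condition: in @inference_can_fuse only result #1 of the tf.PartitionedCall (%1#1, the full sequence) is consumed, so the CHECK lines expect the called function to be rewritten around a tfl.lstm; in @inference_cannot_fuse result #2 (%1#2) is consumed instead, so the CHECK lines expect the function body to remain plain TF ops.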
tensorflow/compiler/mlir/lite/transforms/passes.h (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(

// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareCompositeFunctionsPass();
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();

// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
@@ -35,10 +36,12 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// NOLINTNEXTLINE

@@ -91,14 +94,14 @@ class ConvertEmbeddedLookupFunc {
// body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
: public FunctionPass<PrepareCompositeFunctionsPass> {
: public ModulePass<PrepareCompositeFunctionsPass> {
public:
explicit PrepareCompositeFunctionsPass() {}

private:
void ConvertTFImplements(FuncOp func, StringAttr attr);
void ConvertTFAPIImplements(FuncOp func, StringAttr attr);
void runOnFunction() override;
void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
void runOnModule() override;
};

void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
@@ -131,14 +134,54 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
}
}

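// Returns success if `call_op` produces exactly `expected_num_outputs`
// results and the set of results that actually have uses is exactly
// `expected_consumer_indices`.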
LogicalResult CheckOutputConsumer(
Operation* call_op, int expected_num_outputs,
llvm::DenseSet<int> expected_consumer_indices) {
if (call_op->getNumResults() != expected_num_outputs) return failure();

for (int i = 0; i < expected_num_outputs; ++i) {
auto it = expected_consumer_indices.find(i);
if (it != expected_consumer_indices.end()) {
// This output is expected to be consumed.
if (call_op->getResult(i).use_empty()) return failure();
} else {
// This output must not be consumed.
if (!call_op->getResult(i).use_empty()) return failure();
}
}
return success();
}

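// Returns success only if every call to `lstm_func` in `module` consumes
// just the full-sequence output (result #1) of the five Keras LSTM results.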
LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
bool check_failed = false;
for (auto func : module.getOps<FuncOp>()) {
func.walk([&](Operation* op) {
auto call_op = dyn_cast_or_null<CallOpInterface>(op);
if (call_op && op->getAttrOfType<SymbolRefAttr>("f").getRootReference() ==
lstm_func.getName()) {
// Keras LSTM has 5 outputs.
// We should make sure only the second output (the full sequence) is consumed.
if (failed(CheckOutputConsumer(call_op, 5, {1}))) check_failed = true;
}
});
}

if (check_failed) return failure();
return success();
}

void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
StringAttr attr) {
StringAttr attr,
ModuleOp module) {
// The Keras LSTM's tf.api_implements attribute usually looks like "lstm_abcde91...".
// TODO(b/147436982): we need to make sure that only the
// outputs (the full sequence) are used, not the last_output and not the
// new_states. We will discard everything except the outputs, and the
// outputs have the shape [batch, time, units].
if (attr.getValue().startswith("lstm_")) {
// Check whether the Keras LSTM can be fused; if not, do nothing.
if (failed(CheckFusableKerasLstm(func, module))) return;

func.eraseBody();
func.addEntryBlock();

@@ -148,26 +191,29 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
}
}

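// Running the pass over the whole module (rather than one function at a
// time) is what lets CheckFusableKerasLstm inspect every caller of the
// LSTM function before deciding to rewrite it.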
void PrepareCompositeFunctionsPass::runOnFunction() {
auto func = getFunction();
// We have two kinds of implements:
// 1) tf._implements.
// 2) tf.api_implements.
// We need to handle them separately.
auto tf_implements_attr = func.getAttrOfType<StringAttr>(kTFImplements);
if (tf_implements_attr) {
ConvertTFImplements(func, tf_implements_attr);
} else {
void PrepareCompositeFunctionsPass::runOnModule() {
auto module = getModule();
for (auto func : module.getOps<FuncOp>()) {
// We have two kinds of implements:
// 1) tf._implements.
// 2) tf.api_implements.
// We need to handle them separately.
auto tf_implements_attr = func.getAttrOfType<StringAttr>(kTFImplements);
if (tf_implements_attr) {
ConvertTFImplements(func, tf_implements_attr);
}

auto tf_api_implements_attr =
func.getAttrOfType<StringAttr>(kTFAPIImplements);
if (!tf_api_implements_attr) return;
// TODO(b/147536816): Keras lstm should set up the correct attributes.
ConvertTFAPIImplements(func, tf_api_implements_attr);
if (tf_api_implements_attr) {
// TODO(b/147536816): Keras lstm should set up the correct attributes.
ConvertTFAPIImplements(func, tf_api_implements_attr, module);
}
}
}
} // namespace

std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareCompositeFunctionsPass() {
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
return std::make_unique<PrepareCompositeFunctionsPass>();
}

