diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 0ae1349a77d33b..f74fc0cb70a36e 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -588,16 +588,6 @@ LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits, return success(); } -void UpdateFuncSignature(int batch, int time, int output, - mlir::FuncOp* func_op) { - SmallVector output_shape{batch, time, output}; - auto input_types = func_op->getType().getInputs(); - auto element_type = input_types[0].cast().getElementType(); - auto output_type = mlir::RankedTensorType::get(output_shape, element_type); - func_op->setType( - mlir::FunctionType::get(input_types, output_type, func_op->getContext())); -} - // TODO(b/147436982): Consider refactor this to be more general. LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { // For argument order, please check out standard_lstm under @@ -729,19 +719,29 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { } SmallVector outputs; + SmallVector output_types; + // Due to the existence of the while loop, the timestamp may be unknown + // for the signature, for us, since we know the inputs, we can infer the time + // steps. for (int i = 0; i < 5; ++i) { if (i == 1) { // only this one is the real output. outputs.push_back(final_output); + output_types.push_back(final_output.getType()); } else { auto result_type = func_op.getCallableResults()[i].dyn_cast(); outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f, func_op.getLoc())); + output_types.push_back(result_type); } } + // Update function signatures. + func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(), + output_types, func_op.getContext())); + builder->create(func_op.getLoc(), outputs); return success(); }