Skip to content

Commit

Permalink
update time stamp for function signature: due to the while loop, the …
Browse files Browse the repository at this point in the history
…lstm function signature will have like tensor<1x?x10xf32>, but if we know the input, we can derive the timestamp.

PiperOrigin-RevId: 299236850
Change-Id: I4bf8d93fd724919a6829e72cae19af551bfeb873
  • Loading branch information
renjie-liu authored and tensorflower-gardener committed Mar 6, 2020
1 parent 07eff99 commit 1113b22
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,16 +588,6 @@ LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits,
return success();
}

void UpdateFuncSignature(int batch, int time, int output,
mlir::FuncOp* func_op) {
SmallVector<int64_t, 4> output_shape{batch, time, output};
auto input_types = func_op->getType().getInputs();
auto element_type = input_types[0].cast<RankedTensorType>().getElementType();
auto output_type = mlir::RankedTensorType::get(output_shape, element_type);
func_op->setType(
mlir::FunctionType::get(input_types, output_type, func_op->getContext()));
}

// TODO(b/147436982): Consider refactor this to be more general.
LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
// For argument order, please check out standard_lstm under
Expand Down Expand Up @@ -729,19 +719,29 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
}

SmallVector<Value, 5> outputs;
SmallVector<Type, 5> output_types;

// Due to the existence of the while loop, the timestamp may be unknown
// for the signature, for us, since we know the inputs, we can infer the time
// steps.
for (int i = 0; i < 5; ++i) {
if (i == 1) {
// only this one is the real output.
outputs.push_back(final_output);
output_types.push_back(final_output.getType());
} else {
auto result_type =
func_op.getCallableResults()[i].dyn_cast<RankedTensorType>();
outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f,
func_op.getLoc()));
output_types.push_back(result_type);
}
}

// Update function signatures.
func_op.setType(mlir::FunctionType::get(func_op.getType().getInputs(),
output_types, func_op.getContext()));

builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
return success();
}
Expand Down

0 comments on commit 1113b22

Please sign in to comment.