**Describe the bug**
When a `tf.keras.layers.Embedding` layer with `mask_zero=True` precedes an LSTM layer, the LSTM is converted into loops instead of an LSTM op.
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux
- Tensorflow Version: 2.4.0
- Python version: 3.8.12
**To Reproduce**
```python
# minimal example of a problematic keras model
import onnx
import tensorflow as tf
import tf2onnx

x = tf.keras.layers.Input(shape=(4,), dtype="int32")
e = tf.keras.layers.Embedding(5, 5, mask_zero=True)(x)
rnn = tf.keras.layers.LSTM(3, return_sequences=True)(e)[0]
model = tf.keras.Model(inputs=x, outputs=rnn)

# the converted ONNX graph will contain Loop ops instead of an LSTM op
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx.save(onnx_model, "lstm_masking_zero.onnx")
```
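To confirm the symptom, one can inspect the op types of the exported graph. The helper below is my own (not part of tf2onnx); it just checks whether the recurrence was lowered to `Loop` ops with no fused `LSTM` node:

```python
def uses_loop_instead_of_lstm(op_types):
    """True if the graph contains Loop ops but no fused LSTM node."""
    types = set(op_types)
    return "Loop" in types and "LSTM" not in types

# e.g. with the model saved above:
#   m = onnx.load("lstm_masking_zero.onnx")
#   uses_loop_instead_of_lstm(n.op_type for n in m.graph.node)
```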
**Additional context**
I've tried modifying `lstm_tf2_rewriter` to accommodate the new pattern in the rewriter's parsing. Although I can skip the extra `SelectV2` pattern and get an LSTM op in the final ONNX model, I am not able to handle the mask-zero information correctly: my attempt produces incorrect inference results whenever the input contains 0.
Below is my unsuccessful attempt, in which masking is simply ignored. Any suggestion on how masking should be handled? Thank you!
```diff
--- a/tf2onnx/rewriter/lstm_tf2_rewriter.py
+++ b/tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -56,21 +56,22 @@ def rewriter_lstm_tf2(g, ops):
         # extract output h_t
         ht_mul = match_result.get_op("ht")
         final_consumers = g.find_output_consumers(ht_mul.output[0])
-        select_ops = [n for n in final_consumers if n.type == "Select"]
+        select_ops = [n for n in final_consumers if n.type == "Select" or n.type == "SelectV2"]
         def has_tensor_list_consumer(n):
             return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
         select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
+
+        seq_len_idx = None
         if len(select_ops) == 1:
-            greater_eq = select_ops[0].inputs[0]
-            if greater_eq.type != "GreaterEqual":
-                continue
-            seq_len = greater_eq.inputs[1]
-            if not seq_len.is_graph_input():
-                continue
-            seq_len_idx = g.input_names.index(seq_len.output[0])
+            select_op_condition = select_ops[0].inputs[0]
+            if select_op_condition.type == "GreaterEqual":  # has sequence length
+                seq_len = select_op_condition.inputs[1]
+                if not seq_len.is_graph_input():
+                    continue
+                seq_len_idx = g.input_names.index(seq_len.output[0])
+            # if select op's condition doesn't come from GreaterEqual, we still extract
+            # output h_t from consumer, and seq_len remains empty
             final_consumers = g.find_output_consumers(select_ops[0].output[0])
-        else:
-            seq_len_idx = None
```
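One direction I've considered (a sketch only, not the rewriter's actual behavior): the ONNX LSTM op accepts an optional `sequence_lens` input, which seems like the natural target for `mask_zero`. For right-padded batches the lengths could be derived from the token ids, as in the hypothetical helper below; note that `sequence_lens` can only express trailing padding, so zeros in the middle of a sequence still could not be represented this way.

```python
def seq_lens_from_zero_padding(batch):
    """Per-example sequence lengths, assuming 0 appears only as trailing padding.

    This is what mask_zero implies for right-padded batches; it does NOT
    reproduce masking of zeros in the middle of a sequence, which the ONNX
    LSTM sequence_lens input cannot express.
    """
    lens = []
    for row in batch:
        # length = index of the last nonzero token + 1 (0 for an all-padding row)
        length = 0
        for i, tok in enumerate(row):
            if tok != 0:
                length = i + 1
        lens.append(length)
    return lens

print(seq_lens_from_zero_padding([[3, 1, 4, 0], [2, 0, 0, 0], [0, 0, 0, 0]]))  # [3, 1, 0]
```

In the rewriter this computation would presumably have to be emitted as graph ops on the model input and wired to the LSTM node's `sequence_lens` slot, rather than run in Python.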
Related links:
- Masking behavior in tf.keras and how RNN behavior should change: Keras LSTM not converted correctly when precedent Embedding layer specifies mask_zero=True #1871
- ONNX open issue on masking support for RNN layer: Support masking for LSTM, RNN? onnx#2248