
Keras LSTM not converted correctly when preceding Embedding layer specifies mask_zero=True #1871


Description


Describe the bug
When a tf.keras.layers.Embedding layer with mask_zero=True precedes an LSTM layer, the LSTM is converted into a Loop instead of an LSTM op.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux
  • Tensorflow Version: 2.4.0
  • Python version: 3.8.12

To Reproduce

# minimal example of the problematic Keras model
import tensorflow as tf
import tf2onnx
import onnx

x = tf.keras.layers.Input(shape=(4,), dtype="int32")
e = tf.keras.layers.Embedding(5, 5, mask_zero=True)(x)
rnn = tf.keras.layers.LSTM(3, return_sequences=True)(e)[0]
model = tf.keras.Model(inputs=x, outputs=rnn)

# the converted ONNX graph contains Loop ops instead of an LSTM op
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx.save(onnx_model, "lstm_masking_zero.onnx")
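
A quick way to confirm what the converter emitted (this uses the standard onnx Python API on the model produced above; with mask_zero=True the output shows Loop-related nodes rather than a single LSTM node):

# list the node types in the converted graph; with mask_zero=True this
# shows Loop (plus supporting ops) instead of a fused LSTM node
print([n.op_type for n in onnx_model.graph.node])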

Screenshots

[Screenshot: original Keras model to be converted]

[Screenshot: unsuccessful conversion of the LSTM into loops]

Additional context
I've tried modifying lstm_tf2_rewriter to accommodate the new pattern during rewriter parsing. Although I can skip the extra SelectV2 pattern and get an LSTM op in the final ONNX model, I am not able to handle the zero-masking information correctly: my attempt produces incorrect inference results whenever the input contains a 0.

Below is my unsuccessful attempt: masking is ignored, which leads to incorrect inference results. Any suggestion on how masking should be handled? Thank you!

--- a/tf2onnx/rewriter/lstm_tf2_rewriter.py
+++ b/tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -56,21 +56,22 @@ def rewriter_lstm_tf2(g, ops):
             # extract output h_t
             ht_mul = match_result.get_op("ht")
             final_consumers = g.find_output_consumers(ht_mul.output[0])
-            select_ops = [n for n in final_consumers if n.type == "Select"]
+            select_ops = [n for n in final_consumers if n.type == "Select" or n.type == "SelectV2"]
             def has_tensor_list_consumer(n):
                 return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
             select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
+
+            seq_len_idx = None
             if len(select_ops) == 1:
-                greater_eq = select_ops[0].inputs[0]
-                if greater_eq.type != "GreaterEqual":
-                    continue
-                seq_len = greater_eq.inputs[1]
-                if not seq_len.is_graph_input():
-                    continue
-                seq_len_idx = g.input_names.index(seq_len.output[0])
+                select_op_condition = select_ops[0].inputs[0]
+                if select_op_condition.type == "GreaterEqual": # has sequence length
+                    seq_len = select_op_condition.inputs[1]
+                    if not seq_len.is_graph_input():
+                        continue
+                    seq_len_idx = g.input_names.index(seq_len.output[0])
+                # if the Select op's condition doesn't come from a GreaterEqual,
+                # we still extract output h_t from the consumer; seq_len_idx stays None
                 final_consumers = g.find_output_consumers(select_ops[0].output[0])
-            else:
-                seq_len_idx = None
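
For context on what I think the rewriter would ultimately need: ONNX's LSTM op only expresses masking through its optional sequence_lens input, i.e. a contiguous valid prefix per sample, whereas Keras' Embedding(mask_zero=True) masks any timestep whose token id is 0. Under the assumption that padding zeros are always trailing, the lengths could be recovered from the model input itself. The helper below is a hypothetical numpy sketch of that mapping, not tf2onnx code:

import numpy as np

# Hypothetical helper (not part of tf2onnx): recover the sequence_lens input
# expected by ONNX's LSTM op from zero-padded token ids. This only matches
# Keras' mask_zero semantics if every zero is trailing padding, in which case
# the count of non-zero entries per row equals the valid prefix length.
def sequence_lens_from_zero_mask(token_ids):
    valid = token_ids != 0                     # [batch, time] boolean mask
    return valid.sum(axis=1).astype(np.int32)  # per-sample prefix length

tokens = np.array([[3, 1, 4, 0],
                   [2, 0, 0, 0]], dtype=np.int32)
print(sequence_lens_from_zero_mask(tokens))    # -> [3 1]

If zeros can occur mid-sequence, sequence_lens cannot represent the Keras mask at all, which may be part of why my attempt above produces wrong results.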

