Keras LSTM not converted correctly when precedent Embedding layer specifies mask_zero=True #1871
Comments
Related issue: Keras LSTM converted to loop instead of ONNX LSTM op #1851
@hwangdeyu As far as I understand, the reason it's converted to a loop in this case is that the masking behavior is an intra-op operation. cjermain in this post gave an idea for a solution covering padding (not general masking) which is already implemented in keras2onnx; could tf2onnx adopt the same approach?
Hi @hwangdeyu and tf2onnx contributors, keras2onnx has implemented this feature. I tested with the minimal example Keras model shared above: the keras2onnx implementation gives correct inference results on post-padded inputs. It does so by pre-processing the inputs before the LSTM. Can the same be implemented here to support masking for post-padded inputs?
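The pre-processing idea referenced above amounts to deriving each sequence's valid length from the post-padded input and feeding it to the ONNX LSTM's `sequence_lens` input. A minimal sketch of that length computation (the function name and the pad-token convention are illustrative, not taken from either library):

```python
def sequence_lengths(batch, pad_id=0):
    """Valid length of each post-padded sequence: the count of
    non-padding tokens (assumes padding appears only at the end)."""
    return [sum(1 for tok in row if tok != pad_id) for row in batch]

# Post-padded token-id batch: valid lengths are 2 and 3.
print(sequence_lengths([[4, 7, 0, 0], [1, 2, 3, 0]]))  # → [2, 3]
```

With `mask_zero=True`, token id 0 is the padding marker, which is why `pad_id=0` is the natural default in this sketch.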
Thank you so much for such a detailed issue. I am not very familiar with this part; I guess keras2onnx is a good example to base a fix on. I need to do more investigation. Is there any model blocked by this issue?
Hi, thank you for following up. Yes, we have an older model trained with TF 2.4 that converts correctly with keras2onnx. With keras2onnx deprecated and TF upgraded to TF 2.8, we need to convert the model correctly with tf2onnx. The production model is tricky to share, but the minimal example above should reproduce the issue. Will the community be able to take a look soon?
Hi @hwangdeyu and tf2onnx community, is there any plan to pick up this work soon?
@hwangdeyu we have an urgent request for this fix (07/01). Please let us know if you are able to pick this up sometime soon.
Yeah, sorry for the late reply. We will make a plan to fix this soon.
Hi @hwangdeyu, could you give us an update? We'll need to work around our urgent request (07/01) depending on your plan for this fix. Please let us know.
We went offsite for several days. We are working on this issue, but the progress has been limited so far. I'm not sure it can be solved before July 1st.
@hwangdeyu Thanks for the update. Could you provide us with an ETA?
@hwangdeyu Given the request on our end, we need to work towards a fix too. Could you give us some pointers or direction so that we can contribute and collaborate?
Just as onnx/onnx#2248 (comment) said, the first type has been implemented in keras2onnx. The second type may be better, but I don't think ONNX will change the LSTM op, since recent new models are less likely to use it. I can't think of a better way so far, so my plan is to sync with my colleague and try the same approach as keras2onnx to support masking for post-padded inputs. And thanks, I will ping you if I find anything that needs your contribution.
Thanks for the pointer. Regarding the reproduction, I think the minimal example shared at the top of the post is sufficiently simple to verify a work-in-progress fix. We can assert that both inference results equal the TF result, and assert that the LSTM op is correctly converted. It seems the linked test only verifies that the inference results are equal, but doesn't verify that an LSTM op is produced (instead of a Loop)? Maybe I didn't understand your ask correctly.
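The second assertion suggested here is just a check on node types in the converted graph. A sketch of that check (with a real converted model you would collect `node.op_type` from `model.graph.node` via the `onnx` package; the op-type list below is a stand-in so the sketch stays self-contained):

```python
def lstm_converted(op_types):
    """True if the graph contains a native LSTM node and no Loop fallback."""
    return "LSTM" in op_types and "Loop" not in op_types

# Stand-in op-type lists for a good and a bad conversion.
print(lstm_converted(["Gather", "LSTM", "Squeeze"]))  # → True
print(lstm_converted(["Gather", "Loop", "Squeeze"]))  # → False
```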
Hi @q-ycong-p, I have synced with my colleague and did more tests on it. Just as the CI tests and your example show, the inference results from the original converted model are correct.
Hi @hwangdeyu, thanks for letting me know. Even though the inference result is correct, the LSTM being converted to a Loop op means important optimizations cannot be taken advantage of at inference time. I'm trying to work towards a fix for the post-padded masking scenario, using the example @cjermain from keras2onnx has provided. I will keep you posted and might seek help if stuck.
Hello,
Describe the bug
When a `tf.keras.layers.Embedding` layer with the attribute `mask_zero=True` precedes an LSTM layer, the LSTM is converted into loops instead of an LSTM op.

System information
To Reproduce
Screenshots
Additional context
I've tried modifying `lstm_tf2_rewriter` to accommodate the new pattern in the rewriter parsing. Although I can skip the extra `SelectV2` pattern and get an LSTM op in the final ONNX model, I am not able to correctly handle the mask-zero information: my attempt produces incorrect inference results whenever the input contains a 0.

Below is my unsuccessful attempt, in which masking is ignored and inference results are incorrect. Any suggestion on how masking should be handled? Thank you!
Related links: