[Relay][Frontend][Tensorflow] Fix GatherV2, Add StopGradient #4238
@trevor-m Thank you for adding this. Can you take a look at the CI failure?
Thanks @kevinthesun, the link http://ci.tvm.ai:8080/job/tvm/job/PR-4238/1/display/redirect isn't working for me. Edit:
please kick ci later.
I object to mapping it to identity directly, because we have an AD (automatic differentiation) algorithm. Can it be replaced with an operator that maps to identity and has a gradient of 0?
That makes sense to me. How do you see this being implemented? Do you mean creating a new op in relay or is there a way to do it just in the parser? Looks like nnvm had an op
You have to create an op in relay.
That sounds good, thanks for the description! I’ll update this PR with the change shortly.
…4238) * Add StopGradient. Add batch_dims attr to ignore list for GatherV2 * Trigger CI
These changes are needed to parse a BERT TF model into relay.
GatherV2
TensorFlow recently added batch_dims as an attribute of the GatherV2 op to support batched gather. This caused the parser to fail, since it didn't expect that attribute. I added a check to make sure the attribute, if present, doesn't change the behavior of the op, and also added it to the ignore list. We can support parsing of batched gather at a later date.
StopGradient
This is a training op meant to stop gradient flow over certain edges in the graph. It has no purpose in inference so it can be safely mapped to identity.
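The two behaviours described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this PR: the function names, the contents of IGNORED_ATTRS, the dictionary returned by the converter, and the choice of "take" as the target op name are all hypothetical. It also assumes that a batch_dims value of 0 is the one that leaves the op's behavior unchanged. The last function shows the reviewer's request for an op that acts as identity in the forward direction and has a gradient of 0.

```python
# Hypothetical, simplified stand-in for a frontend op converter.
# None of these names are taken from the actual parser.

# Attributes assumed not to affect the converted op (illustrative list).
IGNORED_ATTRS = {"Tindices", "Tparams", "Taxis", "batch_dims"}


def convert_gather_v2(attrs, axis=0):
    """Describe the converted op, or raise if the attributes are unsupported."""
    # Assumption: a non-zero batch_dims requests batched gather, which changes
    # the op's semantics, so it is rejected rather than silently ignored.
    if int(attrs.get("batch_dims", 0)) != 0:
        raise NotImplementedError(
            "GatherV2 with non-zero batch_dims is not supported"
        )
    # Attributes on the ignore list are dropped; anything else is kept.
    kept = {k: v for k, v in attrs.items() if k not in IGNORED_ATTRS}
    return {"op": "take", "axis": axis, "attrs": kept}


def stop_gradient_forward(values):
    """Forward pass: return the input values unchanged (identity)."""
    return list(values)


def stop_gradient_grad(upstream_grad):
    """Backward pass: block the gradient by returning zeros of the same length."""
    return [0.0 for _ in upstream_grad]
```

With this sketch, a GatherV2 node carrying batch_dims equal to 0 converts normally, while a node with batch_dims equal to 1 raises NotImplementedError instead of producing a silently wrong graph.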