[Relay][Training] Add more missing gradients #6767
Please try to pass lint locally first~
@register_gradient("take") | ||
def take_grad(orig, grad): |
You can get by by defining a 'put' operator that puts a scalar at an index of a tensor and leaves the other places unchanged. put and take have some classic properties which I assume will be better for the optimizer. It also allows other optimizations (e.g. with put and reduce_sum, `grad + (put vala at idxa in 0_array) + (put valb at idxb in 0_array)` will be collapsed into a long chain of puts on grad, allowing COW to kick in so that all take-gradient updates become in-place mutations instead of creating another tensor).
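For intuition, here is a minimal NumPy sketch (hypothetical names, not TVM code) of what a put-based take gradient would compute: the incoming gradient values are scatter-added into a zero tensor at the taken indices, which is exactly the chain-of-puts structure described above.

```python
import numpy as np

def take_grad_via_put(data_shape, indices, grad):
    """Hypothetical helper: accumulate `grad` into a zero array at `indices`."""
    dgrad = np.zeros(data_shape, dtype=grad.dtype)
    # np.add.at is an unbuffered scatter-add, i.e. a chain of put-and-add
    # updates; repeated indices accumulate, matching the take gradient.
    np.add.at(dgrad, indices, grad)
    return dgrad

# take([10., 20., 30.], indices=[0, 2, 0]) reads index 0 twice, so an
# output gradient of all ones flows back as [2., 0., 1.].
print(take_grad_via_put((3,), np.array([0, 2, 0]), np.ones(3)))
```

Expressed this way, consecutive put updates on the same zero array are candidates for fusion, which is the optimization opportunity mentioned above.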
@jroesch please look and comment as well.
This is a good point that I was wondering about. The loop is basically just implementing a `put` operation (like I described in the comment), so it would make sense to have it be a separate op, since I imagine it will be useful in general. Should I remove this gradient for now, or keep it and replace it with `put` once I implement it?
Both are fine.
merging per reviews, feel free to send follow-up improvements per @MarisaKirisame's comment
Added the following gradients:

- take
- reverse_reshape
- stack
- squeeze
- expand_dims
- arange

Also fixed a typo in the type solver diagnostics. I had to use a Relay loop in `take_grad` to support `Any` size in indices.

cc @t-vi @SWu @MarisaKirisame
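For context, the sketch below shows roughly how a Relay gradient is registered, in the style of python/tvm/relay/op/_tensor_grad.py, using expand_dims as the simplest case from the list above. It is an illustration under my assumptions, not the exact code of this PR; the take gradient in particular is more involved, since it uses a Relay loop to handle `Any`-sized indices.

```python
# Rough sketch only: import paths follow python/tvm/relay/op/_tensor_grad.py
# and may need adjusting; re-registering a gradient that already exists
# upstream would normally require passing a higher `level`.
from tvm.relay.op import register_gradient
from tvm.relay.op.transform import reshape_like

@register_gradient("expand_dims")
def expand_dims_grad(orig, grad):
    # expand_dims only inserts axes of size 1, so the gradient w.r.t. the
    # input is the output gradient collapsed back to the input's shape.
    # One gradient is returned per input of the original call.
    return [reshape_like(grad, orig.args[0])]
```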