Calculating gradients for partial graph #7269
Conversation
```python
    return op_path

def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
```
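For context, a toy reverse-mode walk over an op path can illustrate what a `calc_gradient`-style API does conceptually. This is an illustrative sketch, not Paddle's implementation; the `op_path` dict structure, `input_grads`, and `calc_gradient_sketch` are all made-up names for the example:

```python
# Toy sketch (not Paddle's code): backpropagate from `targets` to `inputs`
# along an ordered op path, seeding with `target_gradients` (default 1.0).
def calc_gradient_sketch(op_path, targets, inputs, target_gradients=None):
    seeds = target_gradients or [1.0] * len(targets)
    # grads maps variable name -> accumulated gradient value
    grads = dict(zip(targets, seeds))
    # op_path is ordered forward; traverse it in reverse for backprop
    for op in reversed(op_path):
        out_grad = grads.get(op["output"], 0.0)
        for inp, local_grad in op["input_grads"].items():
            # chain rule with accumulation across multiple consumers
            grads[inp] = grads.get(inp, 0.0) + out_grad * local_grad
    return [grads.get(x, 0.0) for x in inputs]

# y = 3*x, z = 2*y  =>  dz/dx = 6
op_path = [
    {"output": "y", "input_grads": {"x": 3.0}},
    {"output": "z", "input_grads": {"y": 2.0}},
]
grads_x = calc_gradient_sketch(op_path, ["z"], ["x"])  # -> [6.0]
```

The point is that only the op path between `targets` and `inputs` is visited, which is what makes the backward pass "partial".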
I have a few questions about `calc_gradient`:

- It seems that this API appends a partial backward pass to calculate the gradients of `targets` with respect to `inputs`. Can this API be used by users directly? And when would users use it?
- There are some special handlings (like `_rename_grad_`) to make sure this API can be invoked multiple times. However, each invocation changes the `program` itself (some backward ops are appended). So during the second invocation, the backward part generated the first time is still in the program and is treated as part of the forward pass. Is this reasonable?
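To make the renaming concern concrete, here is a hypothetical illustration of why gradient variables need unique names when backward is appended to the same program more than once; the `append_backward_once` helper and the `@RENAME@` suffix scheme are assumptions for the sketch, loosely mimicking what `_rename_grad_` is described as doing:

```python
import itertools

_invoke_counter = itertools.count()

def append_backward_once(program, var):
    """Append a toy backward op whose grad output name is unique per call.

    Without the suffix, a second invocation would emit the same
    "x@GRAD" name and collide with the first invocation's output.
    """
    uid = next(_invoke_counter)
    grad_name = f"{var}@GRAD@RENAME@{uid}"
    program.append({"op": "fill_grad", "output": grad_name})
    return grad_name

program = []
g0 = append_backward_once(program, "x")
g1 = append_backward_once(program, "x")
# The two invocations append distinct grad variables instead of colliding.
```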
- Yes. Some algorithms have their own way of calculating gradients. One example is Learning to Learn by Gradient Descent by Gradient Descent.
- Yes. This happens in cases where gradients from different parts need to be combined in some way to calculate the final result.
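A minimal sketch of the "combine gradients from different parts" use case: two partial backward passes produce per-parameter gradients, which are mixed before the update. The `combine_grads` helper and the linear mix with `alpha` are assumptions for illustration, not anything from this PR:

```python
# Hypothetical: mix gradients from two partial backward passes
# (e.g., two losses sharing parameters) before the final update.
def combine_grads(grad_a, grad_b, alpha=0.5):
    # weighted sum of per-parameter gradients from the two parts
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(grad_a, grad_b)]

mixed = combine_grads([2.0, 0.0], [4.0, 1.0])  # -> [3.0, 0.5]
```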
Got it!
Added backward.calc_gradient to backpropagate gradient from given targets to inputs.
Force-pushed from 5cdbefe to 6e5eae1.
LGTM, thanks!