With the addition of the `AMP<F>` dtype, we also need to add gradient scaling, which is commonly used with AMP training.
I think the frontend interface could look something like:
```rust
let mut scaler = GradientScaler { ... }; // similar fields to pytorch's scaler
// this would do both parts that you have to do in pytorch now:
// 1. would scale the loss by the correct value
// 2. would unscale the gradients before returning them
let grads = scaler.scaled_backward(loss);
```
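For concreteness, here's a minimal sketch of what the struct itself might hold. All field names here are assumptions, borrowed from the knobs on PyTorch's `torch.cuda.amp.GradScaler` (init scale 2^16, growth factor 2.0, backoff factor 0.5, growth interval 2000), not an existing dfdx API:

```rust
/// Hypothetical sketch; fields and defaults mirror PyTorch's GradScaler.
pub struct GradientScaler {
    /// Current loss scale; the loss is multiplied by this before backward.
    scale: f32,
    /// Multiplier applied to `scale` after `growth_interval` overflow-free steps.
    growth_factor: f32,
    /// Multiplier applied to `scale` when an inf/nan gradient is found.
    backoff_factor: f32,
    /// Number of consecutive overflow-free steps before the scale grows.
    growth_interval: usize,
    /// Steps since the last inf/nan was observed.
    growth_tracker: usize,
}

impl Default for GradientScaler {
    fn default() -> Self {
        Self {
            scale: 65536.0, // 2^16, PyTorch's default init_scale
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            growth_tracker: 0,
        }
    }
}
```

The usual update rule would then be: multiply `scale` by `backoff_factor` (and skip the optimizer step) whenever an inf/nan gradient shows up, and multiply it by `growth_factor` after `growth_interval` consecutive clean steps.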
We may have to add some methods to `Gradients` to support scaling them.
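One possible shape for that, using a hypothetical stand-in for the real `Gradients` type just to show the two operations the scaler needs (in-place scaling for the unscale step, and overflow detection for the backoff decision):

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for dfdx's Gradients: a map from a parameter id
/// to a flat buffer of gradient values.
pub struct Gradients {
    grads: HashMap<usize, Vec<f32>>,
}

impl Gradients {
    /// Multiply every gradient by `factor` in place; unscaling would call
    /// this with `1.0 / scale`.
    pub fn scale_by(&mut self, factor: f32) {
        for grad in self.grads.values_mut() {
            for g in grad.iter_mut() {
                *g *= factor;
            }
        }
    }

    /// Returns true if any gradient overflowed to inf/nan, which is the
    /// signal the scaler would use to back off the scale.
    pub fn has_overflow(&self) -> bool {
        self.grads
            .values()
            .any(|g| g.iter().any(|v| !v.is_finite()))
    }
}
```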
Originally posted by @coreylowman in #424 (comment)