Closed
Description
Hello
There seems to be a minor error in how Stable Adam computes the RMS.
It is currently computed as:
rms = torch.norm(p.grad.data.div(root_sqr_avg.maximum(eps_t)), 2)
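This is missing the normalization coefficient for the mean, so it returns the L2 norm (which grows with the number of elements) rather than the root mean square. It should be:
rms = torch.norm(p.grad.data.div(root_sqr_avg.maximum(eps_t)), 2) / math.sqrt(p.grad.data.numel())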
Additionally, another eps can be added in the final normalization.
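To illustrate the missing normalization, here is a minimal, dependency-free sketch. Plain Python lists stand in for tensors, and the helper names `l2_norm` and `rms` as well as the sample values are my own assumptions, not code from the repository. It only shows that the L2 norm scales with the square root of the element count, while dividing by sqrt(n) gives a size-independent RMS.

```python
import math


def l2_norm(values):
    """Euclidean norm sqrt(sum(x^2)), i.e. what torch.norm(x, 2) returns for a flat tensor."""
    return math.sqrt(sum(v * v for v in values))


def rms(values):
    """Root mean square sqrt(mean(x^2)), which equals l2_norm(x) / sqrt(n)."""
    return l2_norm(values) / math.sqrt(len(values))


# Hypothetical per-element ratios grad / max(root_sqr_avg, eps), all equal to 1.
ratios_small = [1.0] * 4
ratios_large = [1.0] * 400

# Without the normalization the value depends on the parameter size.
norm_small = l2_norm(ratios_small)
norm_large = l2_norm(ratios_large)

# With the normalization both parameters report the same RMS.
rms_small = rms(ratios_small)
rms_large = rms(ratios_large)

print(norm_small, norm_large)
print(rms_small, rms_large)
```

With identical per-element ratios, the un-normalized value is ten times larger for the 400-element parameter than for the 4-element one, so larger parameters would be scaled down more aggressively than intended.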
