fix adamw apply gradient #30130
Conversation
Thanks for your contribution!
e0f6567 to 7b8a46f
python/paddle/optimizer/adamw.py (Outdated)
    assert param.dtype == paddle.fluid.core.VarDesc.VarType.FP32, \
        "the type of coeff(float) and parameter(%s) is not consistent." % (param.dtype)
else:
    assert self._coeff.dtype == param.dtype, \
Requiring users to use double feels a bit inconvenient. When coeff is a float, could we drop the requirement on param.dtype? Is it the case that decay_coeff = 1.0 - self._coeff * learning_rate below will necessarily take self._coeff's dtype instead of learning_rate's dtype?
Do you mean removing this whole check entirely?
Done.
1.0 - self._coeff * learning_rate returns a tensor with self._coeff's dtype; if self._coeff and learning_rate have different dtypes, a cast is automatically inserted to convert the type.
Removing the check means any dtype is supported; rewriting the expression as 1.0 - learning_rate * self._coeff makes it follow learning_rate's dtype.
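Below is a minimal sketch (not the PR's code; the tensor value and eager-mode usage are assumed for illustration) of why 1.0 - learning_rate * self._coeff follows learning_rate's dtype when the coefficient is a plain Python float:

```python
# Illustrative sketch only; assumes Paddle 2.x dygraph mode.
import paddle

learning_rate = paddle.to_tensor(0.001, dtype='float32')
coeff = 0.01  # plain Python float, as stored in the optimizer's _coeff

# The scalar coefficient adopts the tensor operand's dtype, so the result
# keeps learning_rate's dtype and no dtype check on param is needed.
decay_coeff = 1.0 - learning_rate * coeff
print(decay_coeff.dtype)  # paddle.float32
```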
LGTM
LGTM
LGTM
PR types
Bug fixes
PR changes
APIs
Describe
paddle.optimizer.AdamW is incompatible (see #29794). The weight decay update param = param - param * lr * coeff is optimized as follows: param = param * (1.0 - lr * coeff).
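As a sketch of this rewrite (illustrative only; the parameter values below are made up), both formulations compute the same decayed parameter, but the second folds the decay into a single scale:

```python
# Illustrative sketch of the algebraic rewrite; values are made up.
import paddle

param = paddle.to_tensor([1.0, -2.0, 3.0], dtype='float32')
lr, coeff = 0.01, 0.1

# Before: subtract the decay term explicitly.
decayed_before = param - param * lr * coeff

# After (this PR): scale the parameter once by (1.0 - lr * coeff).
decayed_after = param * (1.0 - lr * coeff)

print(paddle.allclose(decayed_before, decayed_after))  # True
```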