[Typing][A-43] Add type annotations for paddle/optimizer/lamb.py #65247
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open source project!
python/paddle/optimizer/lamb.py (Outdated)

assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
Don't modify runtime code unless the logic is actually wrong.
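In other words, a typing pass should layer annotations on top and leave the runtime checks exactly as they are. A minimal standalone sketch (the values here are illustrative only):

# Typing-only change: annotations added, runtime logic untouched.
learning_rate: float = 0.001
beta1: float = 0.9
beta2: float = 0.999
epsilon: float = 1e-6

# The pre-existing checks from the hunk above stay as they were:
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None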
python/paddle/optimizer/lamb.py (Outdated)

learning_rate: float | Tensor = 0.001,
lamb_weight_decay: float = 0.01,
beta1: float = 0.9,
beta2: float = 0.999,
epsilon: float = 1e-6,
parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
grad_clip: GradientClipBase | None = None,
exclude_from_weight_decay_fn: Callable | None = None,
multi_precision: bool = False,
always_adapt: bool = False,
name: str | None = None,
Please also write out the None return type of __init__ ~
And for exclude_from_weight_decay_fn: Callable | None = None, please confirm whether it should be exclude_from_weight_decay_fn: Callable[[Tensor], bool] | None = None
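For reference, a self-contained sketch of the suggested signature. The Callable[[Tensor], bool] parameter type is the proposal above and still needs confirming, and the Any aliases merely stand in for paddle's real types so the snippet runs on its own:

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

# Stand-in aliases so this sketch is self-contained; in the real file these
# are paddle's Tensor, _ParameterConfig and GradientClipBase.
Tensor = Any
_ParameterConfig = Any
GradientClipBase = Any

class Lamb:
    def __init__(
        self,
        learning_rate: float | Tensor = 0.001,
        lamb_weight_decay: float = 0.01,
        beta1: float = 0.9,
        beta2: float = 0.999,
        epsilon: float = 1e-6,
        parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
        grad_clip: GradientClipBase | None = None,
        # Proposed refinement of the bare Callable, to be confirmed:
        exclude_from_weight_decay_fn: Callable[[Tensor], bool] | None = None,
        multi_precision: bool = False,
        always_adapt: bool = False,
        name: str | None = None,
    ) -> None:  # explicit None return type, as requested above
        ...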
python/paddle/optimizer/lamb.py (Outdated)

):
learning_rate: float | Tensor = 0.001,
lamb_weight_decay: float = 0.01,
beta1: float = 0.9,
Regarding the example code:
>>> import paddle
>>> inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
>>> linear = paddle.nn.Linear(10, 10)
>>> out = linear(inp)
>>> loss = paddle.mean(out)
>>> beta1 = paddle.to_tensor([0.9], dtype="float32")
>>> beta2 = paddle.to_tensor([0.85], dtype="float32")
>>> lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
>>> back = out.backward()
>>> lamb.step()
>>> lamb.clear_grad()
Here beta1 and beta2 are defined as Tensors, but they are never passed into Lamb. The example code in #65236 does use them, so I added Tensor to the annotation, and passing them in does not raise an error. It feels odd, so I'm not sure whether Tensor should be added 😂
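For what it's worth, this is presumably how the tensors were meant to be used, i.e. passed through the beta1/beta2 parameters; per the discussion below, this only runs without error because the framework casts them, and whether to bless it in the annotation is the open question:

>>> import paddle
>>> inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
>>> linear = paddle.nn.Linear(10, 10)
>>> out = linear(inp)
>>> loss = paddle.mean(out)
>>> beta1 = paddle.to_tensor([0.9], dtype="float32")
>>> beta2 = paddle.to_tensor([0.85], dtype="float32")
>>> lamb = paddle.optimizer.Lamb(learning_rate=0.002, beta1=beta1, beta2=beta2, parameters=linear.parameters(), lamb_weight_decay=0.01)
>>> back = out.backward()
>>> lamb.step()
>>> lamb.clear_grad()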
Let's not add this for now.
Yeah ~ Actually, I don't think the betas need to support Tensor ~ Following torch, plain float is enough ~
As for why a Tensor can be passed here: I hadn't paid attention to this before, but I just debugged it, and ops.yaml appears to treat these values as attributes. Taking nadam as an example:
(Tensor param, Tensor grad, Tensor learning_rate, Tensor momentum_decay_pow, Tensor beta2_pow, Tensor mu_product, Tensor moment1, Tensor moment2, Tensor master_param, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1.0e-8f, float momentum_decay = 0.004f, bool multi_precision = false)
In the log:
I0618 20:54:07.048274 76391 api.cc:33036] nadam kernel: {"input":["GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32"],"output":["GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32","GPU, NCHW, float32"],"attribute":["float","float","float","float","bool"]}
Then, when a Tensor is passed in at call time, the framework tries to cast it to float or double; see paddle/fluid/pybind/op_function_common.cc:
double CastPyArg2Double(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) {
return PyFloat_AsDouble(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidType(
"%s(): argument (position %d) must be "
"double, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0.0;
}
So I don't think Tensor needs to be added here. Passing one happens to work, but that is just the framework being permissive rather than the intended API design (and the example code likewise has no real need to use Tensor-form betas) ~
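A rough Python-side analogue of that coercion, assuming (as is the case for Paddle's 1-element tensors) that the object implements __float__, which is the hook PyObject_CheckFloatOrToFloat relies on:

>>> import paddle
>>> beta1 = paddle.to_tensor([0.9], dtype="float32")
>>> print(f"{float(beta1):.2f}")  # the same __float__ path the C++ cast above falls back to
0.90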
PR Category
User Experience
PR Types
Improvements
Description
Type annotations:
Related links
@SigureMo @megemini