-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AutoMatedTest support test module.parameter.grad #6043
Conversation
dual_objects_to_test.append( | ||
GetDualObject( | ||
"unused", | ||
getattr(x.pytorch, key).grad, | ||
getattr(x.oneflow, key).grad, | ||
) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增加参数梯度的对比
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@leaves-zwx 在这里加了对 module 参数梯度的对比,才能找到 momentum 没对齐时,参数更新错的问题
目前 bn 的 cpu 实现还没有对齐,torch 的计算公式为:
|
我们cpu和这个计算公式应该就是对齐的吧,后面那个等式可能会造成精度差异。 |
我又核对了一下,是 running_mean 和 running_var 错了 |
self.__setattr__("running_mean", running_mean) | ||
self.__setattr__("running_var", running_var) | ||
# use unbiased variance to update running_var | ||
unbiased_variance = x.var(dim=reduce_axis, unbiased=True, keepdim=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
更新 running_var 时用了无偏估计,后面计算的时候用的是真正的方差
自动测试给出哪个 module 的哪个参数没对齐的提示,这个是打算怎么做呢? |
游离的 tensor 不好办,module 里的参数都带名字的,这个在创建比较集合的时候就把名字传进去,对比出错打印的时候可以打出来 |
CI failed, removing label automerge |
需不需要给batchnorm加入其他参数如 affine 的测试呢? 如果affine为False,这里是运行有错的,要给 functor的gamma和 beta 设置为Optional |
这里就要functor 支持了,我记个 TODO 另外提一个 PR 来改 |
@@ -158,6 +164,8 @@ def forward(self, x): | |||
else: | |||
if self.training: | |||
is_training = True | |||
if self.track_running_stats: | |||
self.num_batches_tracked += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不能写 += 。。。。。因为会触发 Inplace Add,推导 Consistent SBP 有 BUG
Speed stats:
|
TODO: