Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrong inplace acc grad #6146

Merged
merged 10 commits into from
Sep 3, 2021
9 changes: 8 additions & 1 deletion oneflow/core/autograd/autograd_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,15 @@ Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
}
if (autograd_meta->acc_grad()) {
DevVmDepObjectConsumeModeGuard guard(DevVmDepObjectConsumeMode::NONE);
// Should not inplace accumulate grad. For example,
// >>> z = x + y
// >>> p = x / z
// >>> p.sum().backward()
//
// As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value
// for dy if dx is shared with dz.
const auto& output =
JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*inplace=*/true));
JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*inplace=*/false));
JUST(autograd_meta->set_acc_grad(output));
} else {
JUST(autograd_meta->set_acc_grad(current_grad));
Expand Down
9 changes: 8 additions & 1 deletion oneflow/core/framework/tensor_arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ Maybe<void> TensorArg::PushPartialTensor(const std::shared_ptr<Tensor>& partial_
if (!acc_tensor_) {
acc_tensor_ = partial_tensor;
} else {
acc_tensor_ = JUST(functional::Add(partial_tensor, acc_tensor_, /*inplace=*/true));
// Should not inplace accumulate grad. For example,
// >>> z = x + y
// >>> p = x / z
// >>> p.sum().backward()
//
// As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value
// for dy if dx is shared with dz.
acc_tensor_ = JUST(functional::Add(partial_tensor, acc_tensor_, /*inplace=*/false));
}
return Maybe<void>::Ok();
}
Expand Down
9 changes: 9 additions & 0 deletions python/oneflow/test/modules/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
from test_util import GenArgList

from automated_test_util import *
import oneflow as flow
import oneflow.unittest

Expand Down Expand Up @@ -85,6 +86,14 @@ def test_autograd_interface(test_case):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])

@autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3)
def test_accumulate_grad(test_case):
device = random_device()
ndim = random(1, 4).to(int)
x = random_pytorch_tensor(ndim=ndim, requires_grad=True).to(device)
y = random_pytorch_tensor(ndim=ndim, requires_grad=True).to(device)
return x / (x + y)


if __name__ == "__main__":
unittest.main()