Add example of when .data can be unsafe #31


Merged 1 commit on Apr 30, 2018
35 changes: 35 additions & 0 deletions _posts/2018-04-22-0_4_0-migration-guide.md
@@ -86,6 +86,41 @@ True

However, ``.data`` can be unsafe in some cases. Any changes to ``x.data`` are not tracked by ``autograd``, so the computed gradients will be incorrect if ``x`` is needed in a backward pass. A safer alternative is [``x.detach()``](http://pytorch.org/docs/master/autograd.html#torch.Tensor.detach), which also returns a ``Tensor`` that shares data with ``x`` and has ``requires_grad=False``, but which will have its in-place changes reported by ``autograd`` if ``x`` is needed in backward.

Here is an example of the difference between ``.data`` and ``.detach()``, and why we recommend using ``.detach()`` in general.

If you use ``Tensor.detach()``, the gradient computation is guaranteed to be correct: if the detached tensor is later modified in place, ``autograd`` raises an error instead of silently computing wrong gradients.

```
>>> a = torch.tensor([1, 2, 3.], requires_grad=True)
>>> out = a.sigmoid()
>>> c = out.detach()
>>> c.zero_()
tensor([ 0., 0., 0.])

>>> out # modified by c.zero_() !!
tensor([ 0., 0., 0.])

>>> out.sum().backward()  # Requires the original value of out, which was overwritten by c.zero_()
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
```

However, using ``Tensor.data`` can be unsafe and can easily result in incorrect gradients
when a tensor is required for gradient computation but is modified in place.

```
>>> a = torch.tensor([1, 2, 3.], requires_grad=True)
>>> out = a.sigmoid()
>>> c = out.data
>>> c.zero_()
tensor([ 0., 0., 0.])

>>> out # out was modified by c.zero_()
tensor([ 0., 0., 0.])

>>> out.sum().backward()  # Also requires the original value of out, but autograd does not know it changed
>>> a.grad  # The gradient is silently wrong: backward used the zeroed values of `out`
tensor([ 0., 0., 0.])
```
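
For comparison, here is a minimal sketch (output values rounded) of the same computation without any in-place modification. The correct gradient of ``out.sum()`` with respect to ``a`` is ``out * (1 - out)``, the derivative of the sigmoid:

```
>>> a = torch.tensor([1, 2, 3.], requires_grad=True)
>>> out = a.sigmoid()
>>> out.sum().backward()
>>> a.grad  # out * (1 - out): the correct, nonzero gradient
tensor([ 0.1966,  0.1050,  0.0452])
```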

## Support for 0-dimensional (scalar) Tensors
