Make aten.copy preserve strides (hf_Longformer) (pytorch#89464)
Fixes https://github.com/pytorch/torchdynamo/issues/1888

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D41460986](https://our.internmc.facebook.com/intern/diff/D41460986)

Pull Request resolved: pytorch#89464
Approved by: https://github.com/bdhirsh
ezyang authored and pytorchmergebot committed Nov 22, 2022
1 parent 2d94fd3 commit d9cbe77
Showing 5 changed files with 81 additions and 46 deletions.
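For context, a minimal sketch of the scenario this commit targets, mirroring the new test_copy_stride_mismatch test added below (torch.func.functionalize is used here as a stand-in for the internal _functionalize test helper):

import torch
from torch.func import functionalize  # stand-in for the test's _functionalize helper

def f(x):
    # Destination with non-default strides: rows sit 5 elements apart in storage.
    y = torch.empty_strided((2, 2), (5, 1))
    y.copy_(x)
    return y

# Eager mode preserves the destination's strides:
print(f(torch.ones(2, 2)).stride())                 # (5, 1)
# With this change, the functionalized trace does too; previously copy_() was
# rewritten as to() + expand_as(), which produced a contiguous result.
print(functionalize(f)(torch.ones(2, 2)).stride())  # (5, 1)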
46 changes: 29 additions & 17 deletions aten/src/ATen/native/Copy.cpp
@@ -278,27 +278,39 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
return self;
}

// NB: cribbed from https://github.com/pytorch/pytorch/pull/88198
at::Tensor clone_preserve_strides(const at::Tensor& self) {
TORCH_INTERNAL_ASSERT(self.has_storage());
// In cases where the input tensor has internal memory overlap, we cannot actually
// preserve the strides/storage_offset of the input tensor, because
// *_scatter ops will try to copy_() into the cloned tensor.
// However, this should **never** show up in functionalized user code;
// most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
//
// The one place that this does come up is in autograd - if there's a select_scatter
// in the forward, then autograd will generate one for the backward.
// If the input to the select_scatter is grad_output, then this could be an expanded tensor
// with internal overlap.
//if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
// return self.clone();
//}
auto dtype_size = self.dtype().itemsize();
auto nbytes = self.storage().sym_nbytes();
TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
auto numel = nbytes / dtype_size;
auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
auto clone = self_full_size.clone();
auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
return out;
}
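A rough Python analogue of the trick above, assuming the tensor/storage APIs of recent releases (untyped_storage, as_strided); the real implementation operates on SymInts and has no direct Python counterpart:

import torch

def clone_preserve_strides(t: torch.Tensor) -> torch.Tensor:
    # View the *entire* underlying storage as a flat, contiguous 1D tensor,
    # clone that, then re-view the clone with the original sizes/strides/offset.
    # Cloning the whole storage keeps the gaps between strided elements alive,
    # so later copy_()s into the clone behave as they would on the original.
    nbytes = t.untyped_storage().nbytes()
    assert nbytes % t.element_size() == 0
    numel = nbytes // t.element_size()
    flat = t.as_strided((numel,), (1,), 0)
    return flat.clone().as_strided(t.size(), t.stride(), t.storage_offset())

base = torch.arange(12.)
x = base.as_strided((2, 2), (5, 1), 1)   # non-contiguous view with a storage offset
y = clone_preserve_strides(x)
assert y.stride() == (5, 1) and y.storage_offset() == 1
assert torch.equal(x, y)

The rewritten copy() below then simply clones self this way and copy_()s src into the clone, so the functional result inherits self's layout instead of coming back contiguous.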

Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
// copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
// (1) It isn't exposed to the frontend (no python bindings)
// (2) It isn't exposed to the backend (it's a composite, that decomposes into to() and expand_as() calls.
// Note: This implementation doesn't currently preserve the strides of `self`.
// That might be fine for functorch (which already doesn't preserve strides in vmap),
// but it's worth looking into whether or not this implementation will be problematic for LazyTensor/XLA.
auto intermediate = src.to(self, non_blocking);
// We can't use expand() here. Why?
// The contract for copy_() is that the output tensor has the same amount of storage as the original tensor.
// e.g. This should work:
// a = torch.ones(4, 4)
// b = torch.ones(1, 4)
// c = torch.ones(4, 4)
// torch.ops.aten.copy(a, b).add_(c)
// We don't want to emit an extra copy every time though, so we only do it if the shapes are different.
if (self.sym_sizes() != intermediate.sym_sizes()) {
return at::expand_copy_symint(intermediate, self.sym_sizes());
} else {
return intermediate;
}
auto r = clone_preserve_strides(self);
r.copy_(src, non_blocking);
return r;
}

Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1535,6 +1535,8 @@

- func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
variants: function
dispatch:
CompositeExplicitAutogradNonFunctional: copy

- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
variants: method
65 changes: 38 additions & 27 deletions test/test_functionalization.py
@@ -114,6 +114,15 @@ def f(x):

_functionalize(f, reapply_views=True)(torch.ones(3, 3))

def test_copy_stride_mismatch(self):
def f(x):
y = torch.empty_strided((2, 2), (5, 1))
y.copy_(x)
return y

r = _functionalize(f, reapply_views=True)(torch.ones(2, 2))
self.assertEqual(r.stride(), (5, 1))

def test_view_clone_view_inplace(self):
def f(input):
shape = [1, 1024, 128, 128]
@@ -149,13 +158,15 @@ def forward(self, a_1):
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128])
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format); view_copy_5 = None
copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_3); new_empty_strided = view_copy_3 = None
view_copy_4 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
view_copy_5 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format)
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]); view_copy_3 = None
copy_1 = torch.ops.aten.copy.default(view_copy_5, threshold_backward); view_copy_5 = threshold_backward = None
view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None
detach_copy = torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None
view_copy_7 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None
view_copy_7 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None
view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None
return detach_copy_1
@@ -829,8 +840,8 @@ def f(x):
_z = torch._from_functional_tensor(z)
self.assertTrue(are_aliased(_y, _z))

# copy_() gets its own test, because it is special cased in functionalization.
# self.copy_(src) decomposes into src.to(self).expand_as(self).
# copy_() gets its own test, because it used to be special cased in functionalization.
# However, now it works pretty similar to other functional ops
def test_copy_(self):
def f(x):
tmp = torch.zeros(2, 2)
@@ -850,7 +861,8 @@ def f(x):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None
return add
""")

@@ -862,8 +874,9 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add
copy = torch.ops.aten.copy_.default(diagonal, a_1)
add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None
return diagonal
""")

# Test 2: copy_() with same dtype, different shape
@@ -876,8 +889,8 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None
return add
""")

@@ -889,9 +902,9 @@
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
return expand_copy
copy = torch.ops.aten.copy_.default(diagonal, a_1)
add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None
return diagonal
""")

# Test 3: copy_() with different dtype, same shape
@@ -904,8 +917,8 @@
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(_to_copy, a_1); _to_copy = a_1 = None
copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None
return add
""") # noqa: B950

@@ -917,9 +930,9 @@
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add_.Tensor(_to_copy, a_1); a_1 = None
return _to_copy
copy = torch.ops.aten.copy_.default(diagonal, a_1)
add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None
return diagonal
""") # noqa: B950

# Test 4: copy_() with different dtype, different shape
Expand All @@ -932,9 +945,8 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None
add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None
return add
""") # noqa: B950

@@ -946,10 +958,9 @@
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
return expand_copy
copy = torch.ops.aten.copy_.default(diagonal, a_1)
add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None
return diagonal
""") # noqa: B950

def test_expand_symint(self):
3 changes: 1 addition & 2 deletions test/test_fx_reinplace_pass.py
@@ -345,9 +345,8 @@ def forward(self):
ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = None
slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
slice_tensor = torch.ops.aten.slice.Tensor(slice_3, 1, 2, 9223372036854775807); slice_3 = None
copy__default = torch.ops.aten.copy_.default(slice_tensor, ones); slice_tensor = ones = None
return zeros
""")

11 changes: 11 additions & 0 deletions torch/_inductor/decomposition.py
@@ -416,6 +416,17 @@ def all_dim(input, dim, keeepdim=False):
return torch.logical_not(torch.any(torch.logical_not(input), dim, keeepdim))


# NB: this decomposition is not stride accurate, do not put it in the main
# library
@register_decomposition(aten.copy)
def copy(self, src, non_blocking=False):
intermediate = src.to(self, non_blocking)
if self.size() != intermediate.size():
return aten.expand_copy.default(intermediate, self.size())
else:
return intermediate


@register_decomposition(aten.hardswish_)
def hardswish_(x):
return x.copy_(aten.hardswish(x))
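The copy decomposition above reproduces the old composite behaviour (to() followed by expand_copy) and is deliberately kept inductor-only, since its result does not inherit self's strides. A hedged eager-mode illustration (copy_decomp is a hypothetical standalone version of the registered decomposition):

import torch

def copy_decomp(self, src, non_blocking=False):
    intermediate = src.to(self, non_blocking)  # match self's dtype/device
    if self.size() != intermediate.size():
        # broadcast up to self's shape; the result is freshly allocated and contiguous
        return torch.ops.aten.expand_copy.default(intermediate, self.size())
    return intermediate

a = torch.empty_strided((2, 2), (5, 1))  # non-contiguous destination
b = torch.ones(2, 2, dtype=torch.float64)
out = copy_decomp(a, b)
print(out.dtype, out.stride())           # torch.float32 (2, 1) -- strides of `a` are lost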
