
Commit fd61cc9

pbelevich authored and facebook-github-bot committed
Moved at::assert_no_internal_overlap to TensorIterator
Summary: Pull Request resolved: pytorch#22917
Differential Revision: D16521429
Pulled By: pbelevich
fbshipit-source-id: 80ae583c6486d6948431b79e1452902bdf2cfbc3
1 parent 4b78ce1 commit fd61cc9
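
For context, "internal overlap" here means a writable tensor in which several indices map to the same memory location, most commonly a stride-0 view produced by expand(). A minimal sketch of such a tensor (illustrative only, not part of this diff):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor base = at::zeros({1});
      at::Tensor t = base.expand({4});  // sizes [4], strides [0]: every element aliases base[0]
      // at::has_internal_overlap(t) reports MemOverlap::YES for views like this, so using t
      // as the output of an out= / in-place op is rejected by assert_no_internal_overlap.
      return 0;
    }

This commit moves that rejection out of the individual kernels and into TensorIterator, so out= kernels opt in with a single flag instead of calling the assert themselves.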

File tree (8 files changed: +55, -32 lines)

aten/src/ATen/MemoryOverlap.cpp
aten/src/ATen/MemoryOverlap.h
aten/src/ATen/native/BinaryOps.cpp
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/UnaryOps.cpp
aten/src/THC/generic/THCTensorMathPointwise.cu
test/test_torch.py

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 7 additions & 9 deletions
@@ -23,17 +23,15 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
   return MemOverlap::TOO_HARD;
 }
 
-void assert_no_internal_overlap(const Tensor& t, const std::string& op) {
-  assert_no_internal_overlap(t.unsafeGetTensorImpl(), op);
+void assert_no_internal_overlap(const Tensor& t) {
+  assert_no_internal_overlap(t.unsafeGetTensorImpl());
 }
 
-void assert_no_internal_overlap(TensorImpl* t, const std::string& op) {
-  if (has_internal_overlap(t) == MemOverlap::YES) {
-    AT_ERROR(
-        op, ": unsupported operation: more than one element of the written-to "
-        "tensor refers to a single memory location. Please clone() the tensor "
-        "before calling ", op);
-  }
+void assert_no_internal_overlap(TensorImpl* t) {
+  TORCH_CHECK(has_internal_overlap(t) != MemOverlap::YES,
+    "unsupported operation: more than one element of the written-to tensor "
+    "refers to a single memory location. Please clone() the tensor before "
+    "performing the operation.");
 }
 
 }
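
A hedged sketch of how the reworked helper behaves after this change; the op-name argument is gone and the thrown message is the op-agnostic one introduced above (the tensor construction is just an illustration):

    #include <ATen/ATen.h>
    #include <ATen/MemoryOverlap.h>

    int main() {
      at::Tensor t = at::zeros({1}).expand({3});  // internally overlapping view
      // On overlap, TORCH_CHECK raises a c10::Error:
      // "unsupported operation: more than one element of the written-to tensor refers
      //  to a single memory location. Please clone() the tensor before performing the operation."
      at::assert_no_internal_overlap(t);
      return 0;
    }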

aten/src/ATen/MemoryOverlap.h

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ enum class MemOverlap { NO, YES, TOO_HARD };
 CAFFE2_API MemOverlap has_internal_overlap(const Tensor& t);
 CAFFE2_API MemOverlap has_internal_overlap(TensorImpl* t);
 
-CAFFE2_API void assert_no_internal_overlap(const Tensor& t, const std::string& op);
-CAFFE2_API void assert_no_internal_overlap(TensorImpl* t, const std::string& op);
+CAFFE2_API void assert_no_internal_overlap(const Tensor& t);
+CAFFE2_API void assert_no_internal_overlap(TensorImpl* t);
 
 }

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 8 additions & 8 deletions
@@ -25,8 +25,8 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   } else if (self.is_sparse()) {
     AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
   }
-  at::assert_no_internal_overlap(result, "add");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   add_stub(iter.device_type(), iter, alpha);
   return result;
 }
@@ -54,8 +54,8 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
     }
     return at::_sparse_div_zerodim_out(result, self, other);
   }
-  at::assert_no_internal_overlap(result, "div");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   div_stub(iter.device_type(), iter);
   return result;
 }
@@ -79,8 +79,8 @@ Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse() || other.is_sparse()) {
     return at::_sparse_mul_out(result, self, other);
   }
-  at::assert_no_internal_overlap(result, "mul");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   mul_stub(iter.device_type(), iter);
   return result;
 }
@@ -125,8 +125,8 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   } else if (self.is_sparse()) {
     AT_ERROR("sub(sparse, dense) is not supported. Use sub(dense, sparse) instead.");
   }
-  at::assert_no_internal_overlap(result, "sub");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   sub_stub(iter.device_type(), iter, alpha);
   return result;
 }
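
The user-visible effect of routing the check through TensorIterator::binary_op is that these out= entry points still fail fast on overlapping outputs, only now via the iterator rather than an explicit assert in each kernel. A small sketch, assuming the standard at::add_out overload:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor out = at::zeros({1}).expand({3});  // overlapping output buffer
      at::Tensor a = at::ones({3});
      at::Tensor b = at::ones({3});
      // binary_op(..., /*check_internal_overlap=*/true) calls check_and_add_output,
      // so this throws the "single memory location" error, the same behavior the
      // removed at::assert_no_internal_overlap(result, "add") call provided.
      at::add_out(out, a, b);
      return 0;
    }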

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 14 additions & 4 deletions
@@ -534,19 +534,29 @@ void TensorIterator::select_all_keeping_dim(int start_dim, IntArrayRef indices)
   }
 }
 
-TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) {
+TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
+    const Tensor& b, bool check_internal_overlap) {
   auto iter = TensorIterator();
-  iter.add_output(out);
+  if (check_internal_overlap) {
+    iter.check_and_add_output(out);
+  } else {
+    iter.add_output(out);
+  }
   iter.add_input(a);
   iter.add_input(b);
   iter.allow_cpu_scalars_ = true;
   iter.build();
   return iter;
 }
 
-TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
+TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a,
+    bool check_internal_overlap) {
   auto iter = TensorIterator();
-  iter.add_output(out);
+  if (check_internal_overlap) {
+    iter.check_and_add_output(out);
+  } else {
+    iter.add_output(out);
+  }
   iter.add_input(a);
   iter.num_outputs_ = 1;
   iter.build();

aten/src/ATen/native/TensorIterator.h

Lines changed: 10 additions & 3 deletions
@@ -6,6 +6,7 @@
 #include <ATen/detail/ScalarTypeConversions.h>
 #include <bitset>
 #include <c10/util/Optional.h>
+#include <ATen/MemoryOverlap.h>
 #ifdef BUILD_NAMEDTENSOR
 #include <ATen/NamedTensorUtils.h>
 #endif
@@ -142,8 +143,10 @@ struct CAFFE2_API TensorIterator {
 
   void foreach_reduced_elt(const loop_subiter_t& loop, bool parallelize=true);
 
-  static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b);
-  static TensorIterator unary_op(Tensor& out, const Tensor& a);
+  static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b,
+    bool check_internal_overlap = false);
+  static TensorIterator unary_op(Tensor& out, const Tensor& a,
+    bool check_internal_overlap = false);
   static TensorIterator nullary_op(Tensor& out);
   static TensorIterator reduce_op(Tensor& out, const Tensor& a);
   static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);
@@ -261,6 +264,11 @@ struct CAFFE2_API TensorIterator {
     num_outputs_++;
   }
 
+  void check_and_add_output(const Tensor& output) {
+    assert_no_internal_overlap(output);
+    add_output(output);
+  }
+
   void add_output(const Tensor& input, Device device, ScalarType dtype) {
     operands_.emplace_back(input, device, dtype);
     num_outputs_++;
@@ -312,7 +320,6 @@ struct CAFFE2_API TensorIterator {
   bool promote_gpu_output_dtypes_ = false;
   bool final_output_ = true;
 };
-
 /// A container-like struct that acts as if it contains splits of a
 /// TensorIterator that can use 32-bit indexing. Taken together the splits cover
 /// the original TensorIterator.
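
For kernel authors, the migration pattern this header change enables is to drop the explicit assert and pass the flag to the factory instead. A minimal sketch of a hypothetical out-kernel (my_op_out is a placeholder name, not part of this diff, and the dispatch stub is elided):

    #include <ATen/ATen.h>
    #include <ATen/native/TensorIterator.h>

    namespace at { namespace native {

    Tensor& my_op_out(Tensor& result, const Tensor& self) {
      // Before this change: at::assert_no_internal_overlap(result, "my_op");
      auto iter = TensorIterator::unary_op(result, self,
          /*check_internal_overlap=*/true);
      // ... run the computation over `iter`, e.g. via a DispatchStub ...
      return result;
    }

    }} // namespace at::native

The call-site diffs below (UnaryOps.cpp, BinaryOps.cpp) follow exactly this pattern.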

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 4 additions & 4 deletions
@@ -47,8 +47,8 @@ Tensor& bitwise_not_(Tensor& self) {
 
 Tensor& bitwise_not_out(Tensor& result, const Tensor& self) {
   checkBackend("bitwise_not", result, self.type().backend());
-  assert_no_internal_overlap(result, "bitwise_not");
-  auto iter = TensorIterator::unary_op(result, self);
+  auto iter = TensorIterator::unary_op(result, self,
+    /*check_internal_overlap=*/true);
   bitwise_not_stub(iter.device_type(), iter);
 #ifdef BUILD_NAMEDTENSOR
   at::namedinference::propagate_names(result, self);
@@ -161,8 +161,8 @@ static void propagate_names_if_namedtensor_enabled(Tensor& result, const Tensor&
   } \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
     checkBackend(#op, result, Backend::CPU); \
-    assert_no_internal_overlap(result, #op); \
-    auto iter = TensorIterator::unary_op(result, self); \
+    auto iter = TensorIterator::unary_op(result, self, \
+      /*check_internal_overlap=*/true); \
     op##_stub(iter.device_type(), iter); \
     return result; \
   }

aten/src/THC/generic/THCTensorMathPointwise.cu

Lines changed: 2 additions & 2 deletions
@@ -196,8 +196,8 @@ static void propagate_names_if_named_tensor_enabled(THCTensor* result, THCTensor
   }; \
   \
   void THCTensor_(NAME)(THCState* state, THCTensor* self_, THCTensor* src) { \
-    THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); \
-    at::assert_no_internal_overlap(self_, #NAME); \
+    THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); \
+    at::assert_no_internal_overlap(self_); \
     if (self_ == src) { \
       if (!THC_pointwiseApply1<scalar_t>(state, self_, Tensor_##NAME##_##REAL##_Op())) { \
         THArgCheck(false, 2, CUTORCH_DIM_WARNING); \

test/test_torch.py

Lines changed: 8 additions & 0 deletions
@@ -12197,6 +12197,14 @@ def test_sinh_unary_mem_overlap(self):
     def test_cosh_unary_mem_overlap(self):
         self.unary_check_mem_overlap(lambda t: t.cosh_())
 
+    @unittest.expectedFailure
+    def test_lerp_mem_overlap(self):
+        start = torch.randn(1, device=device).expand(3, 3)
+        end = torch.randn(3, 3, device=device)
+        weight = torch.randn(3, 3, device=device)
+        with self.assertRaisesRegex(RuntimeError, 'single memory location'):
+            start.lerp_(end, weight)
+
     @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
     def test_reverse_binary_ops_multiple_device(self):
         self.assertEqual(2 + torch.tensor(3), 2 + torch.tensor(3).to("cuda:1"))  # __radd__
