// ScatterGatherChecks.h
#pragma once

#include <vector>

#include <ATen/ATen.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/irange.h>

namespace at { namespace native {

namespace {

// checks whether index.dtype == int64
// and self.dtype == src.dtype if src is a Tensor
static void scatter_gather_dtype_check(
const std::string& method_name,
const Tensor& self,
const Tensor& index,
const c10::optional<Tensor>& src_opt = c10::nullopt
) {
if (index.numel() != 0) {
TORCH_CHECK(
index.scalar_type() == at::ScalarType::Long,
method_name, "(): Expected dtype int64 for index"
);
}
if (src_opt.has_value()) {
auto src = src_opt.value();
TORCH_CHECK(
self.scalar_type() == src.scalar_type(),
method_name, "(): Expected self.dtype to be equal to src.dtype"
);
}
}
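
// Example (hypothetical call site, not part of this header): an operator
// implementation would typically run this check before dispatching, e.g.
//
//   // raises a c10::Error via TORCH_CHECK if index is not int64,
//   // or if self and src have different dtypes
//   scatter_gather_dtype_check("scatter_", self, index, src);
//
// Passing no src (the default c10::nullopt) skips the self/src dtype
// comparison, which is the gather-style usage.
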
// Used for `gather`-like methods
// Note: self means the input tensor here
// Tests:
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.dim() == self.dim()
static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
const Tensor& index
) {
auto self_dims = ensure_nonempty_dim(self.dim());
TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
"Index tensor must have the same number of dimensions as input tensor"
);
for (const auto i : c10::irange(self_dims)) {
if (i != dim) {
TORCH_CHECK(
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
"Size does not match at dimension ", i,
" expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim
);
}
}
}
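
// Example (hypothetical shapes, for illustration only): with
//   self: [4, 5], dim = 1, index: [4, 3]
// the check passes: the tensors have the same number of dimensions and
// index.size(0) == 4 <= self.size(0) == 4 (dim 1 itself is not compared).
// An index of shape [5, 3] would trip the per-dimension size test at d = 0,
// and an index of shape [4, 3, 1] would trip the dimensionality test.
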
// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
// 3. index.dim() == self.dim() == src.dim()
static C10_UNUSED void scatter_shape_check(
const Tensor& self, int64_t dim, const Tensor& index,
const c10::optional<Tensor>& src_opt = c10::nullopt
) {
if (index.numel() == 0) return;
TORCH_CHECK(
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
"Index tensor must have the same number of dimensions as self tensor"
);
bool is_wrong_shape = false;
int64_t self_dims = ensure_nonempty_dim(self.dim());
// Check: index.size(d) <= self.size(d) for all d != dim
for (const auto d : c10::irange(self_dims)) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (d == dim) continue;
if (index_d_size > ensure_nonempty_size(self, d)) {
is_wrong_shape = true;
break;
}
}
// Check: index.size(d) <= src.size(d) for all d if src is Tensor
if (!is_wrong_shape && src_opt.has_value()) {
auto src = src_opt.value();
for (const auto d : c10::irange(self_dims)) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (index_d_size > ensure_nonempty_size(src, d)) {
is_wrong_shape = true;
break;
}
}
}
if (src_opt.has_value()) {
auto src = src_opt.value();
TORCH_CHECK(
ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
"Index tensor must have the same number of dimensions as src tensor"
);
TORCH_CHECK(!is_wrong_shape,
"Expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim,
" and to be smaller size than src ", src.sizes()
);
}
else {
TORCH_CHECK(!is_wrong_shape,
"Expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim
);
}
}
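
// Example (hypothetical shapes, for illustration only): with
//   self: [4, 5], dim = 0, index: [2, 5], src: [3, 5]
// all three tests pass. Shrinking src to [1, 5] fails test 2, since
// index.size(0) == 2 > src.size(0) == 1; growing index to [2, 6] fails
// test 1, since index.size(1) == 6 > self.size(1) == 5 with 1 != dim.
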
} // anonymous namespace
}} // namespace at::native