forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSymbolicShapeMeta.cpp
317 lines (292 loc) · 10.6 KB
/
SymbolicShapeMeta.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#include <c10/core/Contiguity.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/SymbolicShapeMeta.h>
namespace c10 {
// Copy constructor.
//
// The fields that are never written after construction (sizes_, strides_,
// storage_offset_, strides_valid_) are copied in the member initializer list
// without locking. The lazily populated cached fields (numel_ and the
// contiguity / channels-last / dense flags) plus the `available_` bitmask that
// records which of them are populated are copied in the body while holding
// other.mutables_. The set_*() functions write these fields under that same
// mutex, so the copy sees each cached value either fully stored or not at all.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other)
// Non-mutables can be accessed outside the mutex
: sizes_(other.sizes_),
strides_(other.strides_),
storage_offset_(other.storage_offset_),
strides_valid_(other.strides_valid_) {
std::scoped_lock lock(other.mutables_);
// These must be copied under lock, so ignore clang-tidy here!
// NOLINTBEGIN(cppcoreguidelines-prefer-member-initializer)
numel_ = other.numel_;
is_contiguous_ = other.is_contiguous_;
is_channels_last_contiguous_ = other.is_channels_last_contiguous_;
is_channels_last_3d_contiguous_ = other.is_channels_last_3d_contiguous_;
is_channels_last_ = other.is_channels_last_;
is_channels_last_3d_ = other.is_channels_last_3d_;
is_non_overlapping_and_dense_ = other.is_non_overlapping_and_dense_;
// The availability bitmask is copied last, after the values it describes.
available_.store(other.available_.load());
// NOLINTEND(cppcoreguidelines-prefer-member-initializer)
}
// Returns (base, sizes, strides) as SymNodes, where `base` is the first
// heap-allocated SymInt found (sizes scanned before strides) and every size
// and stride has been wrapped into a SymNode via `wrap_node(base)`.
//
// Returns nullopt -- telling the caller to run the ordinary computation --
// when there is no heap-allocated entry to dispatch on, or when every entry
// has a hint (in which case the ordinary computation is preferred too).
static c10::optional<
    std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>>
normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
  SymNode base;
  bool every_entry_hinted = true;

  // NB: sizes/strides guaranteed to be positive, so only need
  // is_heap_allocated
  auto inspect = [&](SymIntArrayRef values) {
    for (const auto& v : values) {
      if (every_entry_hinted && !v.has_hint()) {
        every_entry_hinted = false;
      }
      if (!base && v.is_heap_allocated()) {
        base = v.toSymNode();
      }
    }
  };
  inspect(sizes);
  inspect(strides);

  if (!base || every_entry_hinted) {
    return c10::nullopt;
  }

  // Wrap each entry relative to the chosen base node.
  auto wrap_all = [&base](SymIntArrayRef values) {
    std::vector<SymNode> nodes;
    nodes.reserve(values.size());
    for (const auto& v : values) {
      nodes.emplace_back(v.wrap_node(base));
    }
    return nodes;
  };
  std::vector<SymNode> size_nodes = wrap_all(sizes);
  std::vector<SymNode> stride_nodes = wrap_all(strides);

  return c10::make_optional(std::make_tuple(
      std::move(base), std::move(size_nodes), std::move(stride_nodes)));
}
// Special treatment because of numel: unlike the other layout predicates, the
// contiguity check also receives the (lazily cached) element count. Returns
// false outright when the strides are not valid.
SymBool SymbolicShapeMeta::compute_contiguous() const {
  if (strides_valid_) {
    return _compute_contiguous(
        c10::SymIntArrayRef(sizes_), c10::SymIntArrayRef(strides_), numel());
  }
  return false;
}
// The rest of them
//
// DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) defines
// `SymBool SymbolicShapeMeta::name() const`, which evaluates
// `fallback(sizes, strides)` directly on the SymInt sizes/strides. `nodeimpl`
// is accepted but unused so both macros share one argument list.
//
// DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) defines the same method,
// but first asks normalize_sym_sizes_strides for a base SymNode. If one is
// found, the computation is dispatched to
// `base->nodeimpl(size_nodes, stride_nodes)`; otherwise it falls back to
// `fallback(sizes, strides)`.
//
// Both generated methods return false when strides_valid_ is false.
// NB: only /* */ comments may appear inside the macro bodies, because a //
// comment would swallow the trailing line-continuation backslash.
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \
  SymBool SymbolicShapeMeta::name() const {                    \
    if (!strides_valid_) {                                     \
      return false;                                            \
    }                                                          \
    c10::SymIntArrayRef sizes(sizes_);                         \
    c10::SymIntArrayRef strides(strides_);                     \
    return fallback(sizes, strides);                           \
  }

#define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback)              \
  SymBool SymbolicShapeMeta::name() const {                           \
    if (!strides_valid_) {                                            \
      return false;                                                   \
    }                                                                 \
    auto n = normalize_sym_sizes_strides(sizes_, strides_);           \
    if (n.has_value()) {                                              \
      /* Move out of the optional (it is not used again) instead of */ \
      /* copying the base node and both SymNode vectors.            */ \
      auto [base, size_nodes, stride_nodes] = std::move(*n);          \
      return SymBool(base->nodeimpl(size_nodes, stride_nodes));       \
    } else {                                                          \
      c10::SymIntArrayRef sizes(sizes_);                              \
      c10::SymIntArrayRef strides(strides_);                          \
      return fallback(sizes, strides);                                \
    }                                                                 \
  }

// clang-format off
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d)
DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense)
// clang-format on
// Undefine both helper macros so neither leaks into the rest of the TU.
#undef DEFINE_SYMBOOL_COMPUTE
#undef DEFINE_EAGER_SYMBOOL_COMPUTE
// Glue compute
// NB: this logic very intentionally short circuits if possible. Without
// short circuiting, it causes
// python test/functorch/test_aotdispatch.py -k
// test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run
// very slowly.
//
// compute_is_non_overlapping_and_dense_dim4: used for 4-d tensors. Returns
// true as soon as the tensor is definitely contiguous, or else definitely
// channels-last contiguous. Otherwise it returns the symbolic OR of those two
// flags with the general non-overlapping-and-dense computation.
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
// Populate the cached contiguity flag before it is read below.
// NOTE(review): init_is_contiguous() always evaluates compute_contiguous(),
// even when a value is already cached (set_is_contiguous then discards the new
// one). If the is_contiguous() accessor already initializes lazily, this call
// may be redundant work -- confirm against the header.
init_is_contiguous();
if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
return true;
}
// Only computed if the contiguity check above did not settle the answer.
init_is_channels_last_contiguous();
if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
return true;
}
return is_contiguous() | is_channels_last_contiguous() |
compute_non_overlapping_and_dense();
}
// 5-d channels-last-3d contiguity: false whenever the tensor is already
// (2d) channels-last contiguous. When that is definitely the case, the 3d check
// is skipped entirely; otherwise the result is
// `!channels_last_contiguous && channels_last_contiguous_3d`, kept symbolic.
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
  init_is_channels_last_contiguous();
  if (!definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
    return ~is_channels_last_contiguous() &
        compute_channels_last_contiguous_3d();
  }
  return false;
}
// 5-d channels-last (2d) strides: false whenever the tensor is already
// channels-last-3d contiguous. When that is definitely the case, the stride
// check is skipped; otherwise the result is
// `!channels_last_3d_contiguous && strides_like_channels_last_2d`.
SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
  init_is_channels_last_3d_contiguous();
  if (!definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
    return ~is_channels_last_3d_contiguous() &
        compute_strides_like_channels_last_2d();
  }
  return false;
}
// 5-d channels-last-3d strides: false whenever the tensor is already flagged
// channels-last. When that definitely holds, the 3d stride check is skipped;
// otherwise the result is `!channels_last && strides_like_channels_last_3d`.
SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
  if (!definitely_true(is_channels_last(), __FILE__, __LINE__)) {
    return ~is_channels_last() & compute_strides_like_channels_last_3d();
  }
  return false;
}
// 5-d non-overlapping-and-dense. A tensor that is definitely contiguous,
// definitely channels-last contiguous, or definitely channels-last-3d
// contiguous is reported as true immediately. Otherwise the answer is the
// symbolic OR of those three flags with the general computation.
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
  // `||` evaluates left to right and stops at the first true operand, giving
  // the same order and short circuiting as three consecutive early returns.
  const bool known_dense =
      definitely_true(is_contiguous(), __FILE__, __LINE__) ||
      definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__) ||
      definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__);
  if (known_dense) {
    return true;
  }
  return is_contiguous() | is_channels_last_contiguous() |
      is_channels_last_3d_contiguous() | compute_non_overlapping_and_dense();
}
// Non-overlapping-and-dense for ranks other than 4 and 5: true right away if
// the tensor is definitely contiguous, otherwise
// `contiguous || non_overlapping_and_dense`, kept symbolic.
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
  if (!definitely_true(is_contiguous(), __FILE__, __LINE__)) {
    return is_contiguous() | compute_non_overlapping_and_dense();
  }
  return true;
}
// Caches `val` as the element count unless one is already cached (the first
// stored value wins). The store and the numel_avail bit in available_ are
// both updated while holding mutables_.
// NOLINTNEXTLINE(performance-unnecessary-value-param)
void SymbolicShapeMeta::set_numel(SymInt val) const {
  std::scoped_lock guard(mutables_);
  if (!has_numel()) {
    numel_ = std::move(val);
    available_.fetch_or(numel_avail);
  }
}
// Caches `val` as the contiguity flag unless one is already cached (the first
// stored value wins); the store and the is_contiguous_avail bit are updated
// while holding mutables_.
void SymbolicShapeMeta::set_is_contiguous(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_contiguous()) {
    is_contiguous_ = std::move(val);
    available_.fetch_or(is_contiguous_avail);
  }
}
// Caches `val` as the channels-last-contiguous flag unless one is already
// cached (the first stored value wins); the store and the
// is_channels_last_contiguous_avail bit are updated while holding mutables_.
void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last_contiguous()) {
    is_channels_last_contiguous_ = std::move(val);
    available_.fetch_or(is_channels_last_contiguous_avail);
  }
}
// Caches `val` as the channels-last-3d-contiguous flag unless one is already
// cached (the first stored value wins); the store and the
// is_channels_last_3d_contiguous_avail bit are updated while holding mutables_.
void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last_3d_contiguous()) {
    is_channels_last_3d_contiguous_ = std::move(val);
    available_.fetch_or(is_channels_last_3d_contiguous_avail);
  }
}
// Caches `val` as the channels-last flag unless one is already cached (the
// first stored value wins); the store and the is_channels_last_avail bit are
// updated while holding mutables_.
void SymbolicShapeMeta::set_is_channels_last(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last()) {
    is_channels_last_ = std::move(val);
    available_.fetch_or(is_channels_last_avail);
  }
}
// Caches `val` as the channels-last-3d flag unless one is already cached (the
// first stored value wins); the store and the is_channels_last_3d_avail bit are
// updated while holding mutables_.
void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last_3d()) {
    is_channels_last_3d_ = std::move(val);
    available_.fetch_or(is_channels_last_3d_avail);
  }
}
// Caches `val` as the non-overlapping-and-dense flag unless one is already
// cached (the first stored value wins); the store and the
// is_non_overlapping_and_dense_avail bit are updated while holding mutables_.
void SymbolicShapeMeta::set_is_non_overlapping_and_dense(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_non_overlapping_and_dense()) {
    is_non_overlapping_and_dense_ = std::move(val);
    available_.fetch_or(is_non_overlapping_and_dense_avail);
  }
}
// Computes the element count as the product of all sizes and caches it.
void SymbolicShapeMeta::init_numel() const {
  auto product = multiply_integers(sizes_);
  set_numel(std::move(product));
}
void SymbolicShapeMeta::init_is_contiguous() const {
set_is_contiguous(compute_contiguous());
}
void SymbolicShapeMeta::init_is_channels_last_contiguous() const {
set_is_channels_last_contiguous([&] {
switch (dim()) {
case 5:
case 4: {
return compute_channels_last_contiguous_2d();
}
default:
return SymBool{false};
}
}());
}
// Computes and caches the channels-last-3d-contiguous flag; only 5-d tensors
// can be true, every other rank is false.
void SymbolicShapeMeta::init_is_channels_last_3d_contiguous() const {
  set_is_channels_last_3d_contiguous(
      dim() == 5 ? compute_channels_last_contiguous_3d_dim5()
                 : SymBool{false});
}
void SymbolicShapeMeta::init_is_channels_last() const {
set_is_channels_last([&] {
switch (dim()) {
case 5:
return compute_channels_last_2d_dim5();
case 4:
return compute_strides_like_channels_last_2d();
default:
return SymBool{false};
}
}());
}
// Computes and caches the channels-last-3d flag; only 5-d tensors can be
// true, every other rank is false.
void SymbolicShapeMeta::init_is_channels_last_3d() const {
  set_is_channels_last_3d(
      dim() == 5 ? compute_channels_last_3d_dim5() : SymBool{false});
}
void SymbolicShapeMeta::init_is_non_overlapping_and_dense() const {
set_is_non_overlapping_and_dense([&] {
switch (dim()) {
case 5:
return compute_is_non_overlapping_and_dense_dim5();
case 4:
return compute_is_non_overlapping_and_dense_dim4();
default:
return compute_is_non_overlapping_and_dense_anydim();
}
}());
}
} // namespace c10