forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSymbolicShapeMeta.h
214 lines (188 loc) · 6.95 KB
/
SymbolicShapeMeta.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#pragma once
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/DimVector.h>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <utility>
namespace c10 {
class C10_API SymbolicShapeMeta {
public:
// Basic metadata from which other quantities are derived
SymDimVector sizes_ = {0};
SymDimVector strides_ = {1};
SymInt storage_offset_ = 0;
bool strides_valid_ = true; // e.g. for sparse where there are no strides
SymbolicShapeMeta() = default;
SymbolicShapeMeta(const SymbolicShapeMeta& other);
void refresh_numel() {
// Non-const, don't need to hold mutables_ lock
available_.fetch_and(~numel_avail);
numel_ = 1;
}
void refresh_contiguous() {
// Non-const, don't need to hold mutables_ lock
available_.fetch_and(numel_avail);
is_contiguous_ = false;
is_channels_last_contiguous_ = false;
is_channels_last_3d_contiguous_ = false;
is_channels_last_ = false;
is_channels_last_3d_ = false;
is_non_overlapping_and_dense_ = false;
}
int64_t dim() const {
return static_cast<int64_t>(sizes_.size());
}
// Accessors for derived quantities, computed lazily on first access
bool has_numel() const {
return available_.load() & numel_avail;
}
bool has_is_contiguous() const {
return available_.load() & is_contiguous_avail;
}
bool has_is_channels_last_contiguous() const {
return available_.load() & is_channels_last_contiguous_avail;
}
bool has_is_channels_last_3d_contiguous() const {
return available_.load() & is_channels_last_3d_contiguous_avail;
}
bool has_is_channels_last() const {
return available_.load() & is_channels_last_avail;
}
bool has_is_channels_last_3d() const {
return available_.load() & is_channels_last_3d_avail;
}
bool has_is_non_overlapping_and_dense() const {
return available_.load() & is_non_overlapping_and_dense_avail;
}
// Accessors to cached derived properties
// DO NOT call with mutables_ lock held
const SymInt& numel() const {
if (C10_UNLIKELY(!has_numel())) {
init_numel();
}
return numel_;
}
const SymBool& is_contiguous() const {
if (C10_UNLIKELY(!has_is_contiguous())) {
init_is_contiguous();
}
return is_contiguous_;
}
const SymBool& is_channels_last_contiguous() const {
if (C10_UNLIKELY(!has_is_channels_last_contiguous())) {
init_is_channels_last_contiguous();
}
return is_channels_last_contiguous_;
}
const SymBool& is_channels_last_3d_contiguous() const {
if (C10_UNLIKELY(!has_is_channels_last_3d_contiguous())) {
init_is_channels_last_3d_contiguous();
}
return is_channels_last_3d_contiguous_;
}
const SymBool& is_channels_last() const {
if (C10_UNLIKELY(!has_is_channels_last())) {
init_is_channels_last();
}
return is_channels_last_;
}
const SymBool& is_channels_last_3d() const {
if (C10_UNLIKELY(!has_is_channels_last_3d())) {
init_is_channels_last_3d();
}
return is_channels_last_3d_;
}
const SymBool& is_non_overlapping_and_dense() const {
if (C10_UNLIKELY(!has_is_non_overlapping_and_dense())) {
init_is_non_overlapping_and_dense();
}
return is_non_overlapping_and_dense_;
}
// Assumptions so we can short-circuit computation
// NOTE: Don't need to lock mutables_ since these aren't const
void assume_contiguous(SymBool val = true) {
is_contiguous_ = std::move(val);
available_.fetch_or(is_contiguous_avail);
}
void assume_channels_last_contiguous(SymBool val = true) {
is_contiguous_ = std::move(val);
available_.fetch_or(is_channels_last_contiguous_avail);
}
void assume_channels_last_3d_contiguous(SymBool val = true) {
is_channels_last_3d_contiguous_ = std::move(val);
available_.fetch_or(is_channels_last_3d_contiguous_avail);
}
void assume_channels_last(SymBool val = true) {
is_channels_last_ = std::move(val);
available_.fetch_or(is_channels_last_avail);
}
void assume_channels_last_3d(SymBool val = true) {
is_channels_last_3d_ = std::move(val);
available_.fetch_or(is_channels_last_3d_avail);
}
void assume_non_overlapping_and_dense(SymBool val = true) {
is_non_overlapping_and_dense_ = std::move(val);
available_.fetch_or(is_non_overlapping_and_dense_avail);
}
private:
SymBool compute_contiguous() const;
SymBool compute_channels_last_contiguous_2d() const;
SymBool compute_channels_last_contiguous_3d() const;
SymBool compute_strides_like_channels_last_2d() const;
SymBool compute_strides_like_channels_last_3d() const;
SymBool compute_non_overlapping_and_dense() const;
// These are little wrappers over the real compute_ functions that
// can make use of other contiguity fields to short circuit.
// They need to be implemented separately for SymBool, as SymBool does
// not short circuit.
// TODO: should the SymBool cases avoid the short circuit? Need to reason
// if its correct, and reason if the simpler expressions are better for
// analysis (maybe not!)
SymBool compute_channels_last_contiguous_3d_dim5() const;
SymBool compute_channels_last_2d_dim5() const;
SymBool compute_channels_last_3d_dim5() const;
SymBool compute_is_non_overlapping_and_dense_dim4() const;
SymBool compute_is_non_overlapping_and_dense_dim5() const;
SymBool compute_is_non_overlapping_and_dense_anydim() const;
void init_numel() const;
void init_is_contiguous() const;
void init_is_channels_last_contiguous() const;
void init_is_channels_last_3d_contiguous() const;
void init_is_channels_last() const;
void init_is_channels_last_3d() const;
void init_is_non_overlapping_and_dense() const;
// NOTE: These only set if !has_foo()
void set_numel(SymInt val) const;
void set_is_contiguous(SymBool val) const;
void set_is_channels_last_contiguous(SymBool val) const;
void set_is_channels_last_3d_contiguous(SymBool val) const;
void set_is_channels_last(SymBool val) const;
void set_is_channels_last_3d(SymBool val) const;
void set_is_non_overlapping_and_dense(SymBool val) const;
// Lazily initialized variables, with the corresponding available_ flag
// indicating whether the value has been initialized
mutable std::atomic<int> available_{0};
enum avail {
numel_avail = 1 << 0,
is_contiguous_avail = 1 << 1,
is_channels_last_contiguous_avail = 1 << 2,
is_channels_last_3d_contiguous_avail = 1 << 3,
is_channels_last_avail = 1 << 4,
is_channels_last_3d_avail = 1 << 5,
is_non_overlapping_and_dense_avail = 1 << 6,
};
// Mutex to prevent races when initializing the variable from const accessors
mutable std::mutex mutables_;
mutable SymInt numel_ = 1;
mutable SymBool is_contiguous_{true};
mutable SymBool is_channels_last_contiguous_{false};
mutable SymBool is_channels_last_3d_contiguous_{false};
mutable SymBool is_channels_last_{false};
mutable SymBool is_channels_last_3d_{false};
mutable SymBool is_non_overlapping_and_dense_{true};
};
} // namespace c10