#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>

namespace c10 {

// backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset
constexpr DispatchKeySet backend_dispatch_keyset =
    autogradother_backends | DispatchKeySet(DispatchKey::Dense);
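
// Illustrative note (not in the original file): per the definition of
// autogradother_backends in DispatchKeySet.h, the full backend bitmask is
// already folded into this set, which is what makes membership tests on
// runtime backend keys work. A minimal sketch, assuming that definition:
//
//   static_assert(backend_dispatch_keyset.has(DispatchKey::CPU));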

// See Note [CompositeExplicitAutogradNonFunctional Key]
// We have several types of decompositions in aten, each of which has its own
// alias key. You should register your decomposition to the
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) it's an out-of-place op
// (2) it decomposes into one or more mutating ops
// (3) it has a derivative formula
// (In theory we could also have a separate key for
// "CompositeImplicitAutogradNonFunctional", but there isn't much of a use
// case for it currently.)
// This key is important for "functional" backends like LazyTensor / XLA.
// If you're a backend that only expects to deal with "functional ops",
// then you don't want to decompose a functional op into an op that causes
// aliasing. You should just write a kernel for that functional op directly
// instead!
constexpr DispatchKeySet non_functional_backend_dispatch_keyset =
    backend_dispatch_keyset
        // XLA and LazyTensor are currently the only two backends in core
        // that use the functionalization pass in eager mode.
        .remove(DispatchKey::Sparse)
        .remove_backend(BackendComponent::XLABit)
        .remove_backend(BackendComponent::LazyBit);
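
// Illustrative sketch (not in the original file): the removals above mean XLA
// remains an ordinary backend key but is excluded from the non-functional
// keyset:
//
//   static_assert(backend_dispatch_keyset.has(DispatchKey::XLA));
//   static_assert(
//       !non_functional_backend_dispatch_keyset.has(DispatchKey::XLA));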

bool isBackendDispatchKey(DispatchKey t) {
  return t != DispatchKey::Undefined
      // See Note [No Alias Keys in DispatchKeySet]
      && !isAliasDispatchKey(t)
      // Note [NestedTensor Not Included in Backend Keys]
      // NestedTensor has been explicitly removed from the "backend keyset"
      // due to incompatibility with some kernels, so we don't want it to be
      // included in CompositeExplicitAutograd kernels.
      && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
}
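
// Illustrative sketch (not in the original file): runtime backend keys pass
// this check; alias keys and NestedTensor do not:
//
//   TORCH_INTERNAL_ASSERT(isBackendDispatchKey(DispatchKey::CPU));
//   TORCH_INTERNAL_ASSERT(
//       !isBackendDispatchKey(DispatchKey::CompositeExplicitAutograd));
//   TORCH_INTERNAL_ASSERT(!isBackendDispatchKey(DispatchKey::NestedTensor));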

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask].
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    autograd_dispatch_keyset |
    // See Note [NestedTensor Not Included in Backend Keys]
    // The caveat to that note is that nested_tensor is a special case
    // where we would like to support composite implicit kernels but not
    // explicit kernels; therefore, we manually add the key to
    // math_dispatch_keyset.
    DispatchKeySet{DispatchKey::NestedTensor} |
    // Functionalize should always reuse CompositeImplicit decomps.
    DispatchKeySet{DispatchKey::Functionalize};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
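
// Illustrative sketch (not in the original file): because the full backend
// mask is OR'd in, per-backend NestedTensor runtime keys are members of
// nested_dispatch_keyset:
//
//   static_assert(nested_dispatch_keyset.has(DispatchKey::NestedTensorCPU));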

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      // See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
      // That's why we OR it with a mask of the backend bits here.
      // getRuntimeDispatchKeySet() expects to return a keyset of runtime
      // dispatch keys, like AutogradCPU, but that requires having backend
      // bits.
      return autograd_dispatch_keyset |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::CompositeImplicitAutograd:
      return math_dispatch_keyset;
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return nested_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutograd:
      return backend_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      return non_functional_backend_dispatch_keyset;
    default:
      return DispatchKeySet(t);
  }
}
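
// Illustrative sketch (not in the original file): expanding an alias key into
// its runtime keys. Because the Autograd case above ORs in the backend mask,
// per-backend runtime keys like AutogradCPU are included, while plain backend
// keys like CPU are not:
//
//   DispatchKeySet ks = getRuntimeDispatchKeySet(DispatchKey::Autograd);
//   TORCH_INTERNAL_ASSERT(ks.has(DispatchKey::AutogradCPU));
//   TORCH_INTERNAL_ASSERT(!ks.has(DispatchKey::CPU));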

bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      return autograd_dispatch_keyset.has(toFunctionalityKey(k));
    case DispatchKey::CompositeImplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      // See Note [NestedTensor Not Included in Backend Keys]
      return nested_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor &&
          non_functional_backend_dispatch_keyset.has(k);
    case DispatchKey::FuncTorchBatchedDecomposition:
      return functorch_batched_ks.has(k);
    default:
      return t == k;
  }
}
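
// Illustrative sketch (not in the original file): NestedTensor is the case
// where the two composite alias keys disagree, per the notes above:
//
//   TORCH_INTERNAL_ASSERT(runtimeDispatchKeySetHas(
//       DispatchKey::CompositeImplicitAutograd, DispatchKey::NestedTensor));
//   TORCH_INTERNAL_ASSERT(!runtimeDispatchKeySetHas(
//       DispatchKey::CompositeExplicitAutograd, DispatchKey::NestedTensor));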

// For a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. For a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  switch (t) {
    case DispatchKey::AutogradCPU:
      return DispatchKeySet(DispatchKey::CPU);
    case DispatchKey::AutogradCUDA:
      return DispatchKeySet(DispatchKey::CUDA);
    case DispatchKey::AutogradXLA:
      return DispatchKeySet(DispatchKey::XLA);
    case DispatchKey::AutogradLazy:
      return DispatchKeySet(DispatchKey::Lazy);
    case DispatchKey::AutogradMeta:
      return DispatchKeySet(DispatchKey::Meta);
    case DispatchKey::AutogradMPS:
      return DispatchKeySet(DispatchKey::MPS);
    case DispatchKey::AutogradHPU:
      return DispatchKeySet(DispatchKey::HPU);
    case DispatchKey::AutogradIPU:
      return DispatchKeySet(DispatchKey::IPU);
    case DispatchKey::AutogradXPU:
      return DispatchKeySet(DispatchKey::XPU);
    case DispatchKey::AutogradPrivateUse1:
      return DispatchKeySet(DispatchKey::PrivateUse1);
    case DispatchKey::AutogradPrivateUse2:
      return DispatchKeySet(DispatchKey::PrivateUse2);
    case DispatchKey::AutogradPrivateUse3:
      return DispatchKeySet(DispatchKey::PrivateUse3);
    case DispatchKey::AutogradNestedTensor:
      return DispatchKeySet(DispatchKey::NestedTensor) |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::AutogradOther:
      return autogradother_backends;
    default:
      return DispatchKeySet();
  }
}
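
// Illustrative sketch (not in the original file): mapping an autograd key back
// to its backend, and the empty-set behavior for non-autograd keys:
//
//   TORCH_INTERNAL_ASSERT(
//       getBackendKeySetFromAutograd(DispatchKey::AutogradCUDA)
//           .has(DispatchKey::CUDA));
//   TORCH_INTERNAL_ASSERT(
//       getBackendKeySetFromAutograd(DispatchKey::CPU).empty());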

bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}

std::string toString(DispatchKeySet ts) {
  std::stringstream ss;
  ss << ts;
  return ss.str();
}

std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
  if (ts.empty()) {
    os << "DispatchKeySet()";
    return os;
  }
  os << "DispatchKeySet(";
  bool first = true;
  for (auto k : ts) {
    if (!first) {
      os << ", ";
    }
    os << k;
    first = false;
  }
  os << ")";
  return os;
}
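
// Illustrative sketch (not in the original file): printing a keyset via the
// operator above (toString() produces the same text). The exact key names
// come from operator<<(std::ostream&, DispatchKey), so the output looks
// something like "DispatchKeySet(CPU, AutogradCPU)":
//
//   std::cout << DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCPU})
//             << std::endl;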

DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
  TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
  TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_);

  // Create a masked version of the set representation to ignore previous
  // keys that we've iterated through.
  uint64_t masked_functionality_bits =
      llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
  uint64_t masked_backend_bits =
      llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
      *data_ptr_;

  uint64_t first_functionality_idx =
      llvm::findFirstSet(masked_functionality_bits);
  uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);

  // If there are no keys, set to end iterator value
  if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
      next_functionality_ == iterator::end_iter_mask_val) {
    // Set up state to be the same as end()
    next_functionality_ = iterator::end_iter_mask_val;
    current_dispatchkey_idx_ = iterator::end_iter_key_val;
    next_backend_ = 0;
    current_backendcomponent_idx_ = iterator::end_iter_key_val;
    return *this;
  }

  // The +1 is because of DispatchKey::Undefined and
  // BackendComponent::InvalidBit
  auto new_next_functionality = first_functionality_idx + 1;
  auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
  // and the -num_backends is because the first <num_backends> bits in the
  // keyset are not Dispatch Keys.
  auto next_dispatchkey_idx = new_next_functionality - num_backends;

  // If the current functionality bit is a per-backend bit, we need special
  // handling.
  if (isPerBackendFunctionalityKey(
          static_cast<DispatchKey>(next_dispatchkey_idx))) {
    // case 1: if the current backend is undefined, then there is no valid
    // backend instance of this functionality key, so we can skip it.
    if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // increment the functionality mask so we skip the current functionality
      // bit on the next increment.
      next_functionality_ = new_next_functionality;
      ++(*this);
      return *this;
    }

    // Otherwise, at this point we know what the current backend and
    // functionality bits are.
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    current_backendcomponent_idx_ = new_backendcomponent_idx;

    // Next, we need to set up the masks for the next increment.
    uint64_t next_backendcomponent_bits =
        llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
        full_backend_mask & *data_ptr_;
    uint64_t next_backendcomponent_idx =
        llvm::findFirstSet(next_backendcomponent_bits);
    if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // case 2: the current backend is valid, but there is not another backend
      // in the keyset. In this case, we need to bump the functionality mask
      // and reset the backend mask for the next increment.
      next_functionality_ = new_next_functionality;
      next_backend_ = 0;
    } else {
      // case 3: we have another backend to iterate over. We want to iterate
      // over the same functionality bit next time, but a different backend
      // bit.
      next_backend_ = first_backendcomponent_idx + 1;
    }
  } else {
    // Functionality bits that aren't per backend are simpler to handle. We
    // can ignore the backend bits.
    TORCH_INTERNAL_ASSERT(next_backend_ == 0);
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    next_functionality_ = new_next_functionality;
  }
  return *this;
}
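
// Illustrative sketch (not in the original file): the increment above is what
// expands a single per-backend functionality bit into one runtime key per
// backend bit during iteration:
//
//   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::CUDA});
//   for (DispatchKey k : ks) {
//     // visits DispatchKey::CPU, then DispatchKey::CUDA
//   }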

std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
  std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      offsets_and_masks;
  // manually set the first entry, which corresponds to Undefined.
  offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
  // loop through every functionality key (aside from Undefined).
  for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
    // functionality_idx should be Dense -> 1, ...
    auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
    auto k = static_cast<DispatchKey>(functionality_idx);

    // If the previous functionality was not per-backend, then we can just
    // increment the previous offset. Otherwise, the next offset =
    // previous_offset + num_backends.
    auto next_offset = prev_offset_and_mask.offset +
        (prev_offset_and_mask.mask == 0 ? 1 : num_backends);
    // the mask is used in the runtime index calculation to find the offset of
    // the backend. For non-per-backend functionalities, this offset should
    // always be 0. Otherwise, we need to get the index of the backend (which
    // we can do using a backend mask).
    auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
    offsets_and_masks[functionality_idx] =
        FunctionalityOffsetAndMask(next_offset, next_mask);
  }
  // Sanity check that the computed offset index of the last functionality key
  // is correct. This assumes that the highest priority functionality key is
  // not per-backend.
  TORCH_INTERNAL_ASSERT(
      offsets_and_masks[num_functionality_keys - 1].offset ==
          (num_runtime_entries - 1),
      "num_runtime_entries: ",
      num_runtime_entries,
      " last_offset: ",
      offsets_and_masks[num_functionality_keys - 1].offset);
  return offsets_and_masks;
}
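
// Illustrative walk-through (not in the original file) of the offset
// arithmetic above, given that Dense is functionality index 1 and is
// per-backend:
//   - Undefined: offset 0, mask 0
//   - Dense:     offset 1, mask full_backend_mask (it owns num_backends
//                consecutive runtime slots)
//   - the next functionality key therefore starts at offset 1 + num_backends.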
} // namespace c10