#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <ostream>
#include <vector>

// Memory format is not a property of a Tensor. It is a way to tell an
// operator how the result should be organized in memory and nothing more. That
// means memory format should never be used as a return value of any tensor
// state interrogation function (internally or externally).
//
// Possible options are:
//  Preserve:
//    If any of the input tensors is in channels_last format, the operator
//    output should be in channels_last format.
//
//  Contiguous:
//    Regardless of input tensors' format, the output should be a contiguous
//    Tensor.
//
//  ChannelsLast:
//    Regardless of input tensors' format, the output should be in
//    channels_last format.
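//
// Illustrative sketch only (this assumes the ATen Tensor and factory APIs,
// which are not part of this header; `input` and `out_sizes` are placeholder
// names): an operator honoring a requested memory format might do roughly
//
//   at::MemoryFormat fmt = input.suggest_memory_format(); // Preserve-like
//   at::Tensor out = at::empty(out_sizes, input.options(), fmt);
//
// rather than unconditionally producing a contiguous result.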
namespace c10 {
enum class MemoryFormat : int8_t {
  Contiguous,
  Preserve,
  ChannelsLast,
  ChannelsLast3d,
  NumOptions
};

// If you are seeing this, it means that this call site was not checked to see
// whether the memory format could be preserved, and it was switched to the old
// default behaviour of contiguous.
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()

inline MemoryFormat get_contiguous_memory_format() {
  return MemoryFormat::Contiguous;
}
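
// Hypothetical use of the macro above (illustrative only; `self` and the
// Tensor::contiguous overload taking a MemoryFormat are assumptions, not
// declared in this header):
//
//   at::Tensor out = self.contiguous(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
//
// which expands to self.contiguous(c10::get_contiguous_memory_format()),
// i.e. the old default of MemoryFormat::Contiguous.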

inline std::ostream& operator<<(
    std::ostream& stream,
    at::MemoryFormat memory_format) {
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      TORCH_CHECK(false, "Unknown memory format ", memory_format);
  }
}

// Note: The channels-last stride indices are hardcoded here for better
// performance.
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 4:
      strides[1] = 1;
      strides[3] = sizes[1];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 3:
      strides[0] = 1;
      strides[2] = sizes[0];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast2d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
  return get_channels_last_strides_2d<int64_t>(sizes);
}
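
// Worked example (sizes chosen purely for illustration): for
// [N, C, H, W] = [2, 3, 4, 5],
//   contiguous strides would be [C*H*W, H*W, W, 1] = [60, 20, 5, 1], while
//   get_channels_last_strides_2d returns [H*W*C, 1, W*C, C] = [60, 1, 15, 3],
// i.e. the channels of a given pixel are adjacent in memory.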

template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 5:
      strides[1] = 1;
      strides[4] = sizes[1];
      strides[3] = strides[4] * sizes[4];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 4:
      strides[0] = 1;
      strides[3] = sizes[0];
      strides[2] = strides[3] * sizes[3];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast3d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
  return get_channels_last_strides_3d<int64_t>(sizes);
}
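
// Worked example (sizes chosen purely for illustration): for
// [N, C, D, H, W] = [2, 3, 4, 5, 6],
//   contiguous strides would be [360, 120, 30, 6, 1], while
//   get_channels_last_strides_3d returns [D*H*W*C, 1, H*W*C, W*C, C]
//   = [360, 1, 90, 18, 3].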

// NOTE:
// Below are helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions: each one handles exactly
//    one combination of sizes + memory_format. This way the stride indices
//    form a constant array that is accessed with constant indices, and the
//    compiler can fully unroll the loop over stride indices for better
//    performance.
// 2. There is no error checking in the helper functions; the caller ensures
//    the correctness of the input.
// 3. All helper functions follow the same structure; only the first one is
//    commented in detail.

template <typename T>
inline bool is_channels_last_strides_2d_s4(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  // Special case for the trivial C dimension: default to NCHW.
  if (strides[1] == 0) {
    return false;
  }
  // Loop over the stride indices in channels-last order.
  for (auto& d : {1, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    // Fall back to NCHW as the default layout for ambiguous cases.
    // This is the flaw of inferring memory_format implicitly from strides:
    // an N111 tensor has identical strides for its size-1 dimensions, and
    // two cases could lead us here:
    //   a. an N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    //   b. an N11W contiguous Tensor sliced on the W-dimension
    //      ([N,1,1,1]@[W,W,W,W])
    if (d == 0 && min == strides[1]) {
      return false;
    }
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1;
    //    [H, 1, 1, 1] channels_last stride
    //    [H, H, 1, 1] contiguous stride
    // 2. reject permutations of 1C1W:
    //    [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //    [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
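
// Illustrative behaviour of the helper above (4-d sizes/strides chosen for
// the example):
//   sizes [2, 3, 4, 5], strides [60, 1, 15, 3] -> true  (channels_last)
//   sizes [2, 3, 4, 5], strides [60, 20, 5, 1] -> false (contiguous NCHW)
//   sizes [2, 1, 1, 1], strides [1, 1, 1, 1]   -> false (ambiguous N111,
//                                                 defaults to NCHW)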

template <typename T>
inline bool is_channels_last_strides_3d_s5(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  if (strides[1] == 0) {
    return false;
  }
  for (auto& d : {1, 4, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    if (d == 0 && min == strides[1]) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}

// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to work around (WAR) properly; see issue #24090.
// Without the history of permutations, we cannot infer the memory_format of a
// tensor from a snapshot of its size & stride, e.g.:
//
// 1. We can NOT specify the memory_format of an N111 tensor through strides
//    in a meaningful way;
//
// 2. Two paths can end up with identical size/stride:
//    an N11W contiguous tensor sliced at the W-dimension becomes
//    [N,1,1,1]@[W,W,W,W];
//    an NC11 channels_last tensor sliced at the C-dimension becomes
//    [N,1,1,1]@[C,C,C,C].
//    So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
//    the memory_format of the original tensor.
//
// Due to these limitations, our temporary WAR `is_channels_last_strides` makes
// a best effort to infer whether the original memory_format of a tensor is
// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
// by their importance) are:
// 1. Ensure that normal shape manipulation does not accidentally change the
//    MemoryFormat of an existing tensor.
// 2. Allow users to mark tensors as MemoryFormat::ChannelsLast.
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, although conventionally PyTorch places no
// restriction on trivial strides (strides of size-1 dimensions).
//
// Note that this approach is a compromise; it does not solve the problem
// completely. In many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` serves the objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion), and a best effort is made to maintain the ChannelsLast flag.
//
// Since this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp):
// a. we ensure that the common cases are supported;
// b. we identify the corner cases where the implementation compromises.
//
// Once accumulated permutations replace implicit memory_format through
// strides, we should update our tests and fix the issues they expose.
//
// We used Channels Last 2d as an example above; this is a general problem for
// all the is_channels_last_strides_xd implementations. Please check the helper
// functions (is_channels_last_strides_*d_s*) for more details.
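
// As a concrete illustration of the objectives above (sizes/strides chosen
// for the example): for sizes [2, 1, 4, 1] (an N1H1 tensor),
//   strides [4, 1, 1, 1] (channels_last) -> is_channels_last_strides_2d: true
//   strides [4, 4, 1, 1] (contiguous)    -> is_channels_last_strides_2d: false
// while the ambiguous [N,1,1,1]@[X,X,X,X] snapshot discussed above is never
// reported as channels_last and therefore defaults to the NCHW interpretation.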

template <typename T>
inline bool is_channels_last_strides_2d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 4:
      return is_channels_last_strides_2d_s4(sizes, strides);
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

template <typename T>
inline bool is_channels_last_strides_3d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 5:
      return is_channels_last_strides_3d_s5(sizes, strides);
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

inline bool is_channels_last_strides_2d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_2d<int64_t>(sizes, strides);
}

inline bool is_channels_last_strides_3d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_3d<int64_t>(sizes, strides);
}
} // namespace c10