forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTensorTransformations.cpp
216 lines (182 loc) · 6.38 KB
/
TensorTransformations.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#include <ATen/native/TensorTransformations.h>
#include <ATen/native/IndexKernel.h> // for flip_stub
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <vector>
namespace at {
namespace native {
// Returns a new tensor holding the elements of `self` reversed along every
// dimension listed in `dims`.
//
// The result is allocated with empty_like(self, MemoryFormat::Preserve).
// When there is real work to do, it is filled by flip_stub through a
// TensorIterator whose output pointer/strides have been rewritten so that
// the output is traversed backwards along each flipped dimension.
//
// @param self  input tensor; it is only read here
// @param dims  dimensions to reverse. dim_list_to_bitset wraps them and
//              rejects repeated entries.
// @return      a freshly allocated tensor with the flipped values
Tensor flip(const Tensor& self, IntArrayRef dims) {
const int64_t total_dims = self.dim();
// It wraps the dims and checks that there are no repeated dims
auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);
Tensor out_tensor = at::empty_like(self, MemoryFormat::Preserve);
// Count dimensions in which we need to do work.
// Flipping a dim of size <= 1, or one with stride 0, leaves the values
// unchanged, so those dims are skipped. For every dim that really needs
// flipping, its stride is zeroed in `strides`; that copy is used below to
// tag the flipped dims.
int n = 0;
auto strides = DimVector(self.strides());
for (const auto i : c10::irange(total_dims)) {
if(flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
n++;
strides[i] = 0;
}
}
// Nothing to do, we return fast
if (n == 0 || self.numel() <=1) {
out_tensor.copy_(self);
return out_tensor;
}
// Create a dummy *input* view of `self` with 0 strides at the flipped dims.
// Its zero strides keep TensorIterator from coalescing flipped dims with
// non-flipped ones. They also let us recognise the flipped dims after the
// iterator has (possibly) reordered/merged dimensions.
// NOTE(review): presumably only its strides matter, not its data; confirm
// against the flip_stub kernels.
const auto restrided_self = self.as_strided(self.sizes(), strides);
// Operand 0 = out_tensor, operand 1 = self, operand 2 = restrided_self.
// NOTE(review): mem-overlap checking is likely disabled because
// restrided_self has zero strides (self-overlapping view); confirm.
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.declare_static_dtype_and_device(self.scalar_type(), self.device())
.add_output(out_tensor)
.add_input(self)
.add_input(restrided_self)
.build();
// Base pointer of the output operand, as char* so strides act as byte offsets.
auto* data = reinterpret_cast<char*>(iter.data_ptr(0));
const auto sizes = iter.shape();
// This is a SmallVector of _signed_ ints, so strides can be negated below.
auto strides_bytes = DimVector(iter.strides(0));
const auto strides_self = iter.strides(1);
const auto strides_dummy = iter.strides(2);
// To understand this transformation, think of a 3D cube.
// - The data ptr points to the lower-left most vertex of the cube
// - The strides tell us how to move in each dimension,
// that is, data + stride[i] advances one element in the dimension i
// To flip a dimension:
// - We move the pointer to the opposite vertex of the cube
// - We iterate in the opposite direction (invert the strides)
for (const auto i : c10::irange(iter.ndim())) {
// A flipped dimension is one where the dummy has a zero stride but self
// does not (we zeroed the dummy's stride above).
// Note that it may be the case that strides_dummy[i] = 0 not because we set it, but because
// strides_self[i] == 0. We do not want to do anything there
if (strides_dummy[i] == 0 && strides_self[i] != 0) {
// Jump to the last element along dim i, then walk backwards.
data += strides_bytes[i] * (sizes[i]-1);
strides_bytes[i] *= -1;
}
}
// Install the reversed view of the output into the iterator.
iter._unsafe_set_arg_strides(0, strides_bytes);
iter._unsafe_set_arg_data(0, reinterpret_cast<void*>(data));
// The kernel writes through the rewritten output (presumably copying operand
// 1, `self`); the quantized flag selects the kernel variant.
flip_stub(iter.device_type(), iter, self.is_quantized());
return out_tensor;
}
// CPU implementation of roll: shifts the elements of `self` along a dimension,
// wrapping elements that fall off one end back around to the other.
//
// Only the single-(shift, dim) case is handled here. Every other combination is
// forwarded to roll_common. The single-dim case is built as two narrow() views,
// [start, size) followed by [0, start), concatenated along `dim`.
//
// @param self    input tensor
// @param shifts  number of places to shift (may be negative)
// @param dims    dimension(s) to roll along
// @return        the rolled tensor
Tensor roll_cpu(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
  if (dims.size() != 1 || shifts.size() != 1) {
    return roll_common(self, shifts, dims);
  }
  // avoid a div zero error below.
  if (self.numel() == 0) {
    return self.clone(at::MemoryFormat::Preserve);
  }
  const int64_t dim = dims[0];
  // numel() != 0 here, so every dimension (including this one) has size >= 1.
  const int64_t size = self.size(dim);
  // C++ `%` truncates toward zero, so `shifts[0] % size` lies in
  // (-size, size) and may be negative (unlike Python's floor modulo).
  // Subtracting it from `size` therefore yields a value in (0, 2 * size), and
  // the outer `% size` maps that into [0, size). Reducing first also avoids
  // forming `size - shifts[0]` directly, which overflows int64_t (undefined
  // behavior) for shifts close to INT64_MIN.
  const int64_t start = (size - shifts[0] % size) % size;
  auto t0 = self.narrow(dim, start, size - start);
  auto t1 = self.narrow(dim, 0, start);
  return at::cat({t0, t1}, dim);
}
// Rotates `self` by k quarter-turns in the plane spanned by dims[0] and dims[1].
//
// `k` is reduced modulo 4, so negative values are accepted. The rotation is
// composed from flip + transpose_:
//   1 quarter-turn : flip dims[1], then swap dims[0] <-> dims[1]
//   2 quarter-turns: flip both dims
//   3 quarter-turns: flip dims[0], then swap dims[0] <-> dims[1]
//   0 quarter-turns: contiguous clone of `self`
//
// @param self  input tensor with at least 2 dimensions
// @param k     number of quarter-turns
// @param dims  exactly two distinct, in-range dimensions
// @return      the rotated tensor
Tensor rot90(const Tensor& self, int64_t k, IntArrayRef dims) {
  const int64_t total_dims = self.dim();
  const int64_t total_rot_dims = dims.size();

  TORCH_CHECK(total_rot_dims == 2,
      "expected total rotation dims == 2, but got dims = ", total_rot_dims);
  TORCH_CHECK(total_dims >= 2,
      "expected total dims >= 2, but got total dims = ", total_dims);

  // Safe to index now that exactly two dims are known to be present.
  const int64_t dim0 = dims[0];
  const int64_t dim1 = dims[1];

  // Reject the same axis given twice, including the form d and d - total_dims.
  TORCH_CHECK(dim0 != dim1 && std::abs(dim0 - dim1) != total_dims,
      "expected rotation dims to be different, but got dim0 = ", dim0,
      " and dim1 = ", dim1);
  TORCH_CHECK(dim0 < total_dims && dim0 >= -total_dims,
      "Rotation dim0 out of range, dim0 = ", dim0);
  TORCH_CHECK(dim1 < total_dims && dim1 >= -total_dims,
      "Rotation dim1 out of range, dim1 = ", dim1);

  // Map any k (including negative) onto {0, 1, 2, 3}.
  const int64_t quarter_turns = ((k % 4) + 4) % 4;

  if (quarter_turns == 1) {
    return self.flip({dim1}).transpose_(dim0, dim1);
  }
  if (quarter_turns == 2) {
    return self.flip(dims);
  }
  if (quarter_turns == 3) {
    return self.flip({dim0}).transpose_(dim0, dim1);
  }
  return self.clone(at::MemoryFormat::Contiguous);
}
// Reverses `self` along dimension 1 (left <-> right); needs at least 2 dims.
Tensor fliplr(const Tensor& self) {
  TORCH_CHECK(self.dim() >= 2, "Input must be >= 2-d.");
  constexpr int64_t kLeftRightDim = 1;
  return self.flip({kLeftRightDim});
}
// Reverses `self` along dimension 0 (up <-> down); needs at least 1 dim.
Tensor flipud(const Tensor& self) {
  TORCH_CHECK(self.dim() >= 1, "Input must be >= 1-d.");
  constexpr int64_t kUpDownDim = 0;
  return self.flip({kUpDownDim});
}
// Promotes a 0-d tensor to shape {1}; any tensor that already has at least
// one dimension is returned as-is.
Tensor atleast_1d(const Tensor& self) {
  const bool is_scalar = self.dim() == 0;
  return is_scalar ? self.reshape({1}) : self;
}
// Applies the single-tensor atleast_1d to every element of `tensors`,
// preserving order.
std::vector<Tensor> atleast_1d(TensorList tensors) {
  std::vector<Tensor> promoted;
  promoted.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    promoted.push_back(at::native::atleast_1d(tensor));
  }
  return promoted;
}
// Promotes `self` to at least 2 dimensions:
//   0-d -> shape {1, 1}
//   1-d -> a new leading dimension of size 1 (unsqueeze(0))
//   otherwise -> returned as-is
Tensor atleast_2d(const Tensor& self) {
  const auto ndim = self.dim();
  if (ndim == 0) {
    return self.reshape({1, 1});
  }
  if (ndim == 1) {
    return self.unsqueeze(0);
  }
  return self;
}
// Applies the single-tensor atleast_2d to every element of `tensors`,
// preserving order.
std::vector<Tensor> atleast_2d(TensorList tensors) {
  std::vector<Tensor> promoted;
  promoted.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    promoted.push_back(at::native::atleast_2d(tensor));
  }
  return promoted;
}
// Promotes `self` to at least 3 dimensions:
//   0-d -> shape {1, 1, 1}
//   1-d -> size-1 dims added at the front and back (unsqueeze(0), unsqueeze(-1))
//   2-d -> size-1 dim appended at the back (unsqueeze(-1))
//   otherwise -> returned as-is
Tensor atleast_3d(const Tensor& self) {
  const auto ndim = self.dim();
  if (ndim == 0) {
    return self.reshape({1, 1, 1});
  }
  if (ndim == 1) {
    return self.unsqueeze(0).unsqueeze(-1);
  }
  if (ndim == 2) {
    return self.unsqueeze(-1);
  }
  return self;
}
// Applies the single-tensor atleast_3d to every element of `tensors`,
// preserving order.
std::vector<Tensor> atleast_3d(TensorList tensors) {
  std::vector<Tensor> promoted;
  promoted.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    promoted.push_back(at::native::atleast_3d(tensor));
  }
  return promoted;
}
// Defines the flip_stub dispatcher (included via ATen/native/IndexKernel.h).
// flip() above invokes it with the iterator's device type.
DEFINE_DISPATCH(flip_stub);
}} // namespace at::native