forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMaxUnpooling.cpp
316 lines (277 loc) · 9.83 KB
/
MaxUnpooling.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/cpu/MaxUnpoolKernel.h>
#include <c10/util/irange.h>
namespace at {
namespace native {
/// CPU forward pass of 2D max-unpooling, writing into a caller-provided tensor.
///
/// Validates the arguments, resizes `output` to (C, H_out, W_out) for a 3-d
/// input or (N, C, H_out, W_out) for a 4-d input, fills it with zeros and then
/// invokes `max_unpool2d_kernel` (skipped when the output has no elements).
///
/// @param self_       Input tensor; must be 3-d or 4-d, and every dimension
///                    from index 1 onwards must be non-empty.
/// @param indices_    Tensor of scalar type Long with the same sizes as `self_`.
/// @param output_size Exactly two elements: (height, width) of the output.
/// @param output      Tensor that is resized, zeroed and written in place.
/// @return Reference to `output`.
/// @throws Raises through TORCH_CHECK when any of the validations fails.
Tensor& max_unpooling2d_forward_out_cpu(
    const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    Tensor& output) {
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices_.scalar_type());
  TORCH_CHECK(
      output_size.size() == 2,
      "There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, but got a tensor with ", self_.ndimension(), " dimensions.");
  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be same as that of the input tensor (", self_.sizes(),
      ") but got indices tensor with shape: ", indices_.sizes());
  for (const auto i : c10::irange(1, self_.ndimension())) {
    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cpu(): ",
        "Expected input to have non-zero size for non-batch dimensions, but got ",
        self_.sizes(), " with dimension ", i , " being empty.");
  }

  // Index output_size only after its length has been validated above;
  // reading it first would touch elements that may not exist.
  const auto oheight = output_size[0];
  const auto owidth = output_size[1];

  // Keep whatever layout the input suggests for the working copies and
  // (in the 4-d case) for the output.
  auto memory_format = self_.suggest_memory_format();
  auto self = self_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  if (self.ndimension() == 3) {
    int64_t numChannels = self.size(0);
    output.resize_({numChannels, oheight, owidth});
  } else {
    int64_t numBatch = self.size(0);
    int64_t numChannels = self.size(1);
    output.resize_({numBatch, numChannels, oheight, owidth}, memory_format);
  }
  // Zero first: the kernel is only handed the input values and their indices.
  output.zero_();

  if (output.numel() != 0) {
    max_unpool2d_kernel(kCPU, output, self, indices);
  }

  return output;
}
/// Functional variant of the 2D max-unpooling forward pass on CPU.
///
/// Allocates an empty tensor with the options of `self` and lets
/// `max_unpooling2d_forward_out_cpu` resize and fill it.
///
/// @param self        Input tensor forwarded to the out-variant.
/// @param indices     Indices tensor forwarded to the out-variant.
/// @param output_size Requested output extents forwarded to the out-variant.
/// @return The freshly allocated result tensor.
Tensor max_unpooling2d_forward_cpu(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  // Zero-element placeholder; the out-variant is responsible for sizing it.
  Tensor result = at::empty({0}, self.options());
  at::native::max_unpooling2d_forward_out_cpu(
      self, indices, output_size, result);
  return result;
}
/// Shared argument validation for the 3D max-unpooling forward/backward paths.
///
/// Checks, in order: indices scalar type, input rank (4 or 5), the lengths of
/// `output_size`, `stride` and `padding` (3 each), that `indices` has the same
/// sizes as `input`, that every input dimension from index 1 onwards is
/// non-empty, and that all strides are positive. When `gradOutput` is defined,
/// additionally checks that it has the same rank and slice count as `input`
/// and that its depth/height/width match `output_size`.
///
/// @param input       Input tensor of the unpooling op.
/// @param gradOutput  Gradient w.r.t. the output; may be an undefined Tensor,
///                    in which case the gradOutput checks are skipped.
/// @param indices     Indices tensor; must be of scalar type Long.
/// @param output_size (depth, height, width) of the output.
/// @param stride      Three stride values, each > 0.
/// @param padding     Three padding values (only the length is checked here).
/// @param fn_name     Caller name, used as prefix of the empty-dimension error.
/// @throws Raises through TORCH_CHECK when any validation fails.
static void max_unpooling3d_shape_check(
    const Tensor& input,
    const Tensor& gradOutput,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    const char *fn_name) {
  TORCH_CHECK(
      indices.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64");
  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with ", input.ndimension(), " dimensions.");
  TORCH_CHECK(
      output_size.size() == 3,
      "There should be exactly three elements (depth, height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      stride.size() == 3,
      "There should be exactly three elements (depth, height, width) in stride, but got: ", stride.size(), " elements.");
  TORCH_CHECK(
      padding.size() == 3,
      "There should be exactly three elements (depth, height, width) in padding, but got: ", padding.size(), " elements.");
  TORCH_CHECK(
      input.sizes() == indices.sizes(),
      "Expected shape of indices to be same as that of the input tensor (", input.sizes(),
      ") but got indices tensor with shape: ", indices.sizes());

  for (const auto i : c10::irange(1, input.ndimension())) {
    TORCH_CHECK(input.size(i) > 0, fn_name,
        ": Expected input to have non-zero size for non-batch dimensions, but got ",
        input.sizes(), " with dimension ", i , " being empty.");
  }

  TORCH_CHECK(
      stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
      "strides should be greater than zero, but got stride: ",
      stride);

  if (!gradOutput.defined()) {
    return;
  }

  // output_size is indexed only here, after its length was validated above.
  const int64_t oT = output_size[0];
  const int64_t oH = output_size[1];
  const int64_t oW = output_size[2];

  // For a 5-d input every axis of interest sits one position further right
  // than for a 4-d input.
  const int64_t shift = input.ndimension() == 5 ? 1 : 0;
  const int64_t dimn = 0 + shift;
  const int64_t dimt = 1 + shift;
  const int64_t dimh = 2 + shift;
  const int64_t dimw = 3 + shift;

  // Kept as int64_t so large slice counts are not truncated.
  const int64_t nslices = input.size(dimn);

  // Verify the rank first so the size(dim) lookups below are known to be in
  // range and a rank mismatch is reported with this dedicated message.
  TORCH_CHECK(
      gradOutput.ndimension() == input.ndimension() &&
          gradOutput.size(dimn) == nslices,
      "gradOutput and input Tensors should have same number of dimensions and also the same number of channels/slices");

  TORCH_CHECK(
      oT == gradOutput.size(dimt) && oH == gradOutput.size(dimh) &&
          oW == gradOutput.size(dimw),
      "Inconsistent gradOutput size. oT= ",
      oT,
      ", oH= ",
      oH,
      ", oW= ",
      oW,
      ". gradOutput: ",
      gradOutput.size(dimt),
      "x",
      gradOutput.size(dimh),
      "x",
      gradOutput.size(dimw));
}
/// CPU forward pass of 3D max-unpooling, writing into a caller-provided tensor.
///
/// Requires `output` to be contiguous, validates the remaining arguments via
/// `max_unpooling3d_shape_check`, resizes `output` to (C, T, H, W) for a 4-d
/// input or (N, C, T, H, W) for a 5-d input, zero-fills it and invokes
/// `max_unpool3d_kernel` (skipped when the output has no elements).
///
/// @param self_       4-d or 5-d input tensor.
/// @param indices_    Indices tensor with the same sizes as `self_`.
/// @param output_size (depth, height, width) of the output.
/// @param stride      Three stride values; only validated here.
/// @param padding     Three padding values; only validated here.
/// @param output      Contiguous tensor that is resized, zeroed and written.
/// @return Reference to `output`.
/// @throws Raises through TORCH_CHECK when a validation fails.
Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");

  // Validate everything (including output_size having three elements)
  // before any element of output_size is read.
  max_unpooling3d_shape_check(
      self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");

  const int64_t oT = output_size[0];
  const int64_t oH = output_size[1];
  const int64_t oW = output_size[2];

  // Contiguous working copies are made only once the arguments are known
  // to be valid.
  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  if (self_.ndimension() == 5) {
    output.resize_({self.size(0), self.size(1), oT, oH, oW});
  } else {
    output.resize_({self.size(0), oT, oH, oW});
  }
  output.zero_();

  if (output.numel() != 0) {
    max_unpool3d_kernel(kCPU, output, self, indices);
  }
  return output;
}
/// Functional variant of the 3D max-unpooling forward pass on CPU.
///
/// Allocates an empty tensor with the options of `self` and delegates to
/// `max_unpooling3d_forward_out_cpu`, which resizes and fills it.
///
/// @param self        Input tensor forwarded to the out-variant.
/// @param indices     Indices tensor forwarded to the out-variant.
/// @param output_size Requested output extents forwarded to the out-variant.
/// @param stride      Stride values forwarded to the out-variant.
/// @param padding     Padding values forwarded to the out-variant.
/// @return The freshly allocated result tensor.
Tensor max_unpooling3d_forward_cpu(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  // Zero-element placeholder; sizing happens inside the out-variant.
  Tensor result = at::empty({0}, self.options());
  at::native::max_unpooling3d_forward_out_cpu(
      self, indices, output_size, stride, padding, result);
  return result;
}
/// CPU backward pass of 2D max-unpooling, writing into a caller-provided tensor.
///
/// Validates the arguments, checks that the height/width of `grad_output_`
/// agree with `output_size`, then resizes `grad_input` to the sizes of `self`,
/// zero-fills it and invokes `max_unpool2d_backward_kernel` (skipped when
/// `grad_input` has no elements).
///
/// @param grad_output_ Gradient w.r.t. the forward output.
/// @param self         Input tensor of the forward pass; its rank selects
///                     which axes are treated as height and width.
/// @param indices_     Tensor of scalar type Long with the same sizes as `self`.
/// @param output_size  Exactly two elements: (height, width) of the forward output.
/// @param grad_input   Contiguous tensor that is resized, zeroed and written.
/// @return Reference to `grad_input`.
/// @throws Raises through TORCH_CHECK when a validation fails; in that case
///         `grad_input` has not been resized or zeroed by this function.
Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
    const Tensor& self,
    const Tensor& indices_,
    IntArrayRef output_size,
    Tensor& grad_input) {
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got type: ", indices_.scalar_type());
  TORCH_CHECK(
      self.sizes() == indices_.sizes(),
      "Expected shape of indices to be same as that of the input tensor (",
      self.sizes(), ") but got indices tensor with shape: ", indices_.sizes());
  TORCH_CHECK(output_size.size() == 2, "Output size must be 2 but got: ", output_size.size());

  // Index output_size only after its length has been validated.
  const int64_t oheight = output_size[0];
  const int64_t owidth = output_size[1];

  // A 3-d input has no leading batch axis, so H/W sit one position earlier.
  const int64_t ndim = self.ndimension();
  const int64_t dimh = ndim == 3 ? 1 : 2;
  const int64_t dimw = ndim == 3 ? 2 : 3;

  auto memory_format = self.suggest_memory_format();
  auto grad_output = grad_output_.contiguous(memory_format);
  auto indices = indices_.contiguous(memory_format);

  // Reject an inconsistent gradOutput before touching grad_input, so a
  // failing call does not leave the output resized and zeroed.
  TORCH_CHECK(
      owidth == grad_output.size(dimw) && oheight == grad_output.size(dimh),
      "Inconsistent gradOutput size. output height = ",
      oheight,
      ", output width = ",
      owidth,
      ", gradOutput: ",
      grad_output.size(dimh),
      "x",
      grad_output.size(dimw));

  grad_input.resize_(self.sizes(), memory_format);
  grad_input.zero_();

  if (grad_input.numel() != 0) {
    max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
  }
  return grad_input;
}
/// Functional variant of the 2D max-unpooling backward pass on CPU.
///
/// Allocates an empty tensor with the options of `self` and delegates to
/// `max_unpooling2d_backward_out_cpu`, which resizes and fills it.
///
/// @param grad_output Gradient w.r.t. the forward output.
/// @param self        Input tensor of the forward pass.
/// @param indices     Indices tensor forwarded to the out-variant.
/// @param output_size Output extents forwarded to the out-variant.
/// @return The freshly allocated gradient w.r.t. the input.
Tensor max_unpooling2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  // Zero-element placeholder; sizing happens inside the out-variant.
  Tensor result = at::empty({0}, self.options());
  at::native::max_unpooling2d_backward_out_cpu(
      grad_output, self, indices, output_size, result);
  return result;
}
/// CPU backward pass of 3D max-unpooling, writing into a caller-provided tensor.
///
/// Requires `grad_input` to be contiguous, validates the remaining arguments
/// via `max_unpooling3d_shape_check`, checks that the depth/height/width of
/// `grad_output_` agree with `output_size`, then resizes `grad_input` like
/// `self`, zero-fills it and invokes `max_unpool3d_backward_kernel` (skipped
/// when `grad_input` has no elements).
///
/// @param grad_output_ Gradient w.r.t. the forward output.
/// @param self         Input tensor of the forward pass (4-d or 5-d).
/// @param indices_     Indices tensor with the same sizes as `self`.
/// @param output_size  (depth, height, width) of the forward output.
/// @param stride       Three stride values; only validated here.
/// @param padding      Three padding values; only validated here.
/// @param grad_input   Contiguous tensor that is resized, zeroed and written.
/// @return Reference to `grad_input`.
/// @throws Raises through TORCH_CHECK when a validation fails; in that case
///         `grad_input` has not been resized or zeroed by this function.
Tensor& max_unpooling3d_backward_out_cpu(
    const Tensor& grad_output_,
    const Tensor& self,
    const Tensor& indices_,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input) {
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  // Validate everything (including output_size having three elements)
  // before any element of output_size is read.
  max_unpooling3d_shape_check(
      self, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cpu()");

  const int64_t oT = output_size[0];
  const int64_t oH = output_size[1];
  const int64_t oW = output_size[2];

  // A 4-d input has no leading batch axis, so T/H/W sit one position earlier.
  const int64_t ndim = self.ndimension();
  const int64_t dimt = ndim == 4 ? 1 : 2;
  const int64_t dimh = ndim == 4 ? 2 : 3;
  const int64_t dimw = ndim == 4 ? 3 : 4;

  /* get contiguous gradOutput */
  auto grad_output = grad_output_.contiguous();
  auto indices = indices_.contiguous();

  // Reject an inconsistent gradOutput before touching grad_input, so a
  // failing call does not leave the output resized and zeroed.
  TORCH_CHECK(
      oW == grad_output.size(dimw) && oH == grad_output.size(dimh) &&
          oT == grad_output.size(dimt),
      "Inconsistent gradOutput size. output depth = ",
      oT,
      ", output height = ",
      oH,
      ", output width = ",
      oW,
      ", gradOutput: ",
      grad_output.size(dimt),
      "x",
      grad_output.size(dimh),
      "x",
      grad_output.size(dimw));

  /* resize */
  grad_input.resize_as_(self);
  grad_input.zero_();

  if (grad_input.numel() != 0) {
    max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
  }
  return grad_input;
}
/// Functional variant of the 3D max-unpooling backward pass on CPU.
///
/// Allocates an empty tensor with the options of `self` and delegates to
/// `max_unpooling3d_backward_out_cpu`, which resizes and fills it.
///
/// @param grad_output Gradient w.r.t. the forward output.
/// @param self        Input tensor of the forward pass.
/// @param indices     Indices tensor forwarded to the out-variant.
/// @param output_size Output extents forwarded to the out-variant.
/// @param stride      Stride values forwarded to the out-variant.
/// @param padding     Padding values forwarded to the out-variant.
/// @return The freshly allocated gradient w.r.t. the input.
Tensor max_unpooling3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  // Zero-element placeholder; sizing happens inside the out-variant.
  Tensor result = at::empty({0}, self.options());
  at::native::max_unpooling3d_backward_out_cpu(
      grad_output, self, indices, output_size, stride, padding, result);
  return result;
}
// Definitions of the four dispatch stubs that the functions in this file
// invoke with kCPU as the first argument.
// NOTE(review): the matching declarations presumably live in
// ATen/native/cpu/MaxUnpoolKernel.h (included by this file) — confirm.
DEFINE_DISPATCH(max_unpool2d_kernel);
DEFINE_DISPATCH(max_unpool2d_backward_kernel);
DEFINE_DISPATCH(max_unpool3d_kernel);
DEFINE_DISPATCH(max_unpool3d_backward_kernel);
} // namespace native
} // namespace at