
Commit 302ee7b

seemethere and xwang233 authored
[release/1.10] Fix adaptive_max_pool2d for channels-last on CUDA (pytorch#67697) (pytorch#69618)
Co-authored-by: Xiao Wang <24860335+xwang233@users.noreply.github.com>
1 parent 0c91a70 commit 302ee7b

2 files changed: +27 −12 lines changed

aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu

Lines changed: 27 additions & 11 deletions
```diff
@@ -211,6 +211,9 @@ const Tensor& indices) {
   int64_t osizeH = output_size[0];
   int64_t osizeW = output_size[1];
 
+  const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
+  const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options());
+
   if (input.ndimension() == 3) {
     int64_t sizeD = input.size(0);
     int64_t isizeH = input.size(1);
@@ -223,8 +226,8 @@ const Tensor& indices) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
           scalar_t* input_data = input.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          scalar_t* output_data = output_c.data_ptr<scalar_t>();
+          int64_t* indices_data = indices_c.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -268,8 +271,8 @@ const Tensor& indices) {
         "adaptive_max_pool2d_cuda",
         [&] {
           scalar_t* input_data = input_.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          scalar_t* output_data = output_c.data_ptr<scalar_t>();
+          int64_t* indices_data = indices_c.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -296,6 +299,13 @@ const Tensor& indices) {
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   }
+
+  if (!output.is_contiguous()) {
+    output.copy_(output_c);
+  }
+  if (!indices.is_contiguous()) {
+    indices.copy_(indices_c);
+  }
 }
 
 TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
@@ -322,7 +332,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
   bool atomic =
       true; // suboptimal, but without atomic it doesn't pass the tests
 
-  Tensor gradOutput_ = gradOutput.contiguous();
+  const at::Tensor gradOutput_ = gradOutput.contiguous();
+  const at::Tensor indices_ = indices.contiguous();
+  const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options());
 
   if (input.ndimension() == 3) {
     int64_t sizeD = input.size(0);
@@ -334,17 +346,17 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
 
     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
 
-    gradInput.zero_();
+    gradInput_c.zero_();
 
     AT_DISPATCH_FLOATING_TYPES_AND2(
         kHalf,
         kBFloat16,
         input.scalar_type(),
         "adaptive_max_pool2d_backward_cuda",
         [&] {
-          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          int64_t* indices_data = indices_.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -393,7 +405,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
     int64_t osizeH = gradOutput_.size(2);
     int64_t osizeW = gradOutput_.size(3);
 
-    gradInput.zero_();
+    gradInput_c.zero_();
 
     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
 
@@ -403,9 +415,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
         input.scalar_type(),
         "adaptive_max_pool2d_backward_cuda",
         [&] {
-          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          int64_t* indices_data = indices_.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -446,6 +458,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
           }
         });
   }
+
+  if (!gradInput.is_contiguous()) {
+    gradInput.copy_(gradInput_c);
+  }
 }
 } // at::native
 } // at
```
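The change follows one pattern in both the forward and backward kernels: write kernel results into contiguous temporaries (`output_c`, `indices_c`, `gradInput_c`) and copy them back only when the user-visible tensors are non-contiguous, as they are for channels-last outputs. The sketch below is not part of the commit; it only illustrates the user-facing forward behavior the fix targets, with arbitrary shapes and pool sizes, comparing a channels-last CUDA call against a contiguous CPU reference.

```python
# Sketch only: CUDA channels-last result vs. contiguous CPU reference.
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    x_cpu = torch.randn(2, 3, 8, 8)
    x_cuda = x_cpu.cuda().contiguous(memory_format=torch.channels_last)

    out_cpu, idx_cpu = F.adaptive_max_pool2d(x_cpu, (4, 4), return_indices=True)
    out_cuda, idx_cuda = F.adaptive_max_pool2d(x_cuda, (4, 4), return_indices=True)

    # With the contiguous-temporary fix, values and indices agree
    # regardless of the output's memory format.
    assert torch.allclose(out_cpu, out_cuda.cpu())
    assert torch.equal(idx_cpu, idx_cuda.cpu())
```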

test/test_nn.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -14622,7 +14622,6 @@ def test_upsamplingBilinear2d(self, device):
 
         self.assertEqual(a_cuda.grad, a_cpu.grad)
 
-    @onlyCPU
     @dtypes(torch.float, torch.double)
     def test_adaptive_pooling_max_nhwc(self, device, dtype):
         def helper(n, c, h, w, output_height, output_width, contig):
```
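Dropping `@onlyCPU` lets `test_adaptive_pooling_max_nhwc` run on CUDA devices as well, covering the channels-last path fixed above. A rough approximation of the kind of parity such a test checks, written against the public API only (not the actual test code; shapes are illustrative):

```python
# Forward and backward parity between a contiguous and a channels-last
# CUDA input; the backward pass exercises adaptive_max_pool2d_backward.
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    ref = torch.randn(2, 3, 8, 8, device="cuda", requires_grad=True)
    nhwc = ref.detach().contiguous(memory_format=torch.channels_last).requires_grad_()

    out_ref, idx_ref = F.adaptive_max_pool2d(ref, (4, 4), return_indices=True)
    out_nhwc, idx_nhwc = F.adaptive_max_pool2d(nhwc, (4, 4), return_indices=True)

    out_ref.sum().backward()
    out_nhwc.sum().backward()

    assert torch.allclose(out_ref, out_nhwc)
    assert torch.equal(idx_ref, idx_nhwc)
    assert torch.allclose(ref.grad, nhwc.grad)
```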
