@@ -211,6 +211,9 @@ const Tensor& indices) {
   int64_t osizeH = output_size[0];
   int64_t osizeW = output_size[1];
 
+  const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
+  const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options());
+
   if (input.ndimension() == 3) {
     int64_t sizeD = input.size(0);
     int64_t isizeH = input.size(1);
@@ -223,8 +226,8 @@ const Tensor& indices) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
           scalar_t* input_data = input.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          scalar_t* output_data = output_c.data_ptr<scalar_t>();
+          int64_t* indices_data = indices_c.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -268,8 +271,8 @@ const Tensor& indices) {
         "adaptive_max_pool2d_cuda",
         [&] {
           scalar_t* input_data = input_.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          scalar_t* output_data = output_c.data_ptr<scalar_t>();
+          int64_t* indices_data = indices_c.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -296,6 +299,13 @@ const Tensor& indices) {
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
   }
+
+  if (!output.is_contiguous()) {
+    output.copy_(output_c);
+  }
+  if (!indices.is_contiguous()) {
+    indices.copy_(indices_c);
+  }
 }
 
 TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
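Note on the forward hunks above: the pooling kernel writes through raw data pointers and assumes densely packed memory, so when the structured-kernel output/indices tensors are not contiguous, the results are produced in contiguous temporaries (output_c, indices_c) and copied back after the launch. A minimal, self-contained sketch of that scratch-buffer pattern (illustration only, not PyTorch source; fill_ stands in for the CUDA kernel writing its results):

    // Sketch of the contiguous-scratch-then-copy-back pattern (not PyTorch code).
    #include <ATen/ATen.h>

    void write_into(const at::Tensor& out) {
      // Use a contiguous temporary when the caller's tensor is strided.
      const at::Tensor out_c =
          out.is_contiguous() ? out : at::empty(out.sizes(), out.options());

      out_c.fill_(1);  // stand-in for a kernel writing densely packed results

      // Copy the results back through the caller's strides if needed.
      if (!out.is_contiguous()) {
        out.copy_(out_c);
      }
    }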
@@ -322,7 +332,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
   bool atomic =
       true; // suboptimal, but without atomic it doesn't pass the tests
 
-  Tensor gradOutput_ = gradOutput.contiguous();
+  const at::Tensor gradOutput_ = gradOutput.contiguous();
+  const at::Tensor indices_ = indices.contiguous();
+  const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options());
 
   if (input.ndimension() == 3) {
     int64_t sizeD = input.size(0);
@@ -334,17 +346,17 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
 
     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
 
-    gradInput.zero_();
+    gradInput_c.zero_();
 
     AT_DISPATCH_FLOATING_TYPES_AND2(
         kHalf,
         kBFloat16,
         input.scalar_type(),
         "adaptive_max_pool2d_backward_cuda",
         [&] {
-          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          int64_t* indices_data = indices_.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -393,7 +405,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
     int64_t osizeH = gradOutput_.size(2);
     int64_t osizeW = gradOutput_.size(3);
 
-    gradInput.zero_();
+    gradInput_c.zero_();
 
     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
 
@@ -403,9 +415,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
         input.scalar_type(),
         "adaptive_max_pool2d_backward_cuda",
         [&] {
-          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          int64_t* indices_data = indices_.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -446,6 +458,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
           }
         });
   }
+
+  if (!gradInput.is_contiguous()) {
+    gradInput.copy_(gradInput_c);
+  }
 }
 } // at::native
 } // at
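The backward hunks mirror the same pattern: gradOutput and indices are made contiguous before the kernel reads them through raw pointers, and gradInput is accumulated into a contiguous scratch tensor (gradInput_c) that is copied back when the caller's gradInput is strided. A hypothetical repro for the path this exercises, assuming the generated out-overload at::adaptive_max_pool2d_out(out, indices, self, output_size): a correctly sized channels-last out tensor is kept by the structured kernel but is non-contiguous in the default NCHW sense, so the copy-back branch runs.

    // Hypothetical repro sketch, not taken from the PR; requires a CUDA build.
    #include <ATen/ATen.h>

    int main() {
      at::Tensor input = at::randn({1, 3, 8, 8}, at::device(at::kCUDA));
      // Correctly sized but channels-last (hence non-contiguous) out tensors.
      at::Tensor out = at::empty({1, 3, 4, 4}, input.options(),
                                 at::MemoryFormat::ChannelsLast);
      at::Tensor indices = at::empty({1, 3, 4, 4},
                                     input.options().dtype(at::kLong),
                                     at::MemoryFormat::ChannelsLast);
      // Results are computed in contiguous scratch tensors and copied back here.
      at::adaptive_max_pool2d_out(out, indices, input, {4, 4});
      return 0;
    }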