Skip to content

Commit 0ac2986

Browse files
Natalia Gimelshein authored and pytorchmergebot committed
Fixes softmax indexing for large tensors (pytorch#84182)
Fixes pytorch#84144 Pull Request resolved: pytorch#84182 Approved by: https://github.com/janeyx99
1 parent 533203f commit 0ac2986

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,8 @@ cunn_SoftMaxForward(outscalar_t *output, scalar_t *input, int classes)
636636

637637
// forward pointers to batch[blockIdx.x]
638638
// each block handles a sample in the mini-batch
639-
input += blockIdx.x * classes;
640-
output += blockIdx.x * classes;
639+
input += static_cast<int64_t>(blockIdx.x) * classes;
640+
output += static_cast<int64_t>(blockIdx.x) * classes;
641641

642642
const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
643643
const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t);
@@ -672,9 +672,9 @@ cunn_SoftMaxBackward(scalar_t *gradInput, outscalar_t *output, outscalar_t *grad
672672

673673
extern __shared__ unsigned char smem[];
674674
auto sdata = reinterpret_cast<accscalar_t*>(smem);
675-
gradInput += blockIdx.x * classes;
676-
output += blockIdx.x * classes;
677-
gradOutput += blockIdx.x * classes;
675+
gradInput += static_cast<int64_t>(blockIdx.x) * classes;
676+
output += static_cast<int64_t>(blockIdx.x) * classes;
677+
gradOutput += static_cast<int64_t>(blockIdx.x) * classes;
678678

679679
const int shift = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
680680
const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t);

test/test_nn.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15803,7 +15803,7 @@ def test_softmax_results(self, device, dtype):
1580315803
@largeTensorTest("20GB")
1580415804
@largeTensorTest("90GB", "cpu")
1580515805
@precisionOverride({torch.half: 0.001})
15806-
def test_softmax_64bit_indexing(self, device, dtype):
15806+
def test_warp_softmax_64bit_indexing(self, device, dtype):
1580715807
def run_test(*shape):
1580815808
x = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
1580915809
y = F.log_softmax(x, dim=-1, dtype=dtype)
@@ -15818,6 +15818,22 @@ def run_test(*shape):
1581815818
run_test(1100000000, 2) # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
1581915819
run_test(2200000000, 1) # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
1582015820

15821+
@onlyCUDA
15822+
@dtypes(torch.half)
15823+
@largeTensorTest("20GB")
15824+
@largeTensorTest("90GB", "cpu")
15825+
@precisionOverride({torch.half: 0.001})
15826+
def test_softmax_64bit_indexing(self, device, dtype):
15827+
def run_test(*shape):
15828+
x = torch.ones(shape, device=device, dtype=dtype, requires_grad=True)
15829+
y = F.log_softmax(x, dim=-1, dtype=dtype)
15830+
y.backward(y)
15831+
self.assertEqual(y[0], y[-1])
15832+
self.assertEqual(x.grad[0], x.grad[-1])
15833+
15834+
run_test(1024 * 256 + 1, 8192) # https://github.com/pytorch/pytorch/issues/84144
15835+
15836+
1582115837
@dtypes(torch.float)
1582215838
@dtypesIfCUDA(torch.float, torch.half)
1582315839
def test_log_softmax_big(self, device, dtype):

0 commit comments

Comments (0)