Skip to content

Commit

Permalink
fix scatter CPU kernel when (input size, src size) > index size (pytorch#25839)
Browse files Browse the repository at this point in the history

Summary:
fixes pytorch#25836
According to the docs, https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_ , `index` must be the tensor with the smallest size, so we should iterate over `index` instead of `tensor`.
cc: dlibenzi
Pull Request resolved: pytorch#25839

Differential Revision: D17269116

Pulled By: ailzhang

fbshipit-source-id: 0e8569fed6c0d2dd70e4e3ec5d29d8730cd2ae8f
  • Loading branch information
Ailing Zhang authored and facebook-github-bot committed Sep 10, 2019
1 parent 5dfef47 commit 26f67e7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
16 changes: 8 additions & 8 deletions aten/src/TH/generic/THTensorApply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,31 @@
}

// Used for `scatter` and `scatterAdd`
// Assumes TENSOR1 is real
// TENSOR2 is src
// TENSOR3 is index
// Assumes TENSOR1 is index
// TENSOR2 is real
// TENSOR3 is src
// Tests:
// 1. index->size(d) <= src->size(d) for all d
// 2. index->size(d) <= real->size(d) for all d != dim
#define TH_TENSOR_DIM_APPLY3_SIZE_SCATTER(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \
{ \
int shape_check_flag = 0; \
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR1); TH_TENSOR_DIM_APPLY_i++) \
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR2); TH_TENSOR_DIM_APPLY_i++) \
{ \
int64_t TENSOR3##_dim_size = THTensor_sizeLegacyNoScalars(TENSOR3, TH_TENSOR_DIM_APPLY_i); \
int64_t TENSOR1##_dim_size = THTensor_sizeLegacyNoScalars(TENSOR1, TH_TENSOR_DIM_APPLY_i); \
if (TH_TENSOR_DIM_APPLY_i != DIMENSION) { \
if (TENSOR3##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR1, TH_TENSOR_DIM_APPLY_i)) { \
if (TENSOR1##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR2, TH_TENSOR_DIM_APPLY_i)) { \
shape_check_flag = 1; \
break; \
} \
} \
if (TENSOR3##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR2, TH_TENSOR_DIM_APPLY_i)) { \
if (TENSOR1##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR3, TH_TENSOR_DIM_APPLY_i)) { \
shape_check_flag = 1; \
break; \
} \
} \
if (shape_check_flag == 1) { \
AT_ERROR("Expected ", #TENSOR3, " ", TENSOR3->sizes(), " to be smaller size than ", #TENSOR2, " ", TENSOR2->sizes(), " and to be smaller than ", #TENSOR1, " ", TENSOR1->sizes(), " apart from dimension ", DIMENSION); \
AT_ERROR("Expected ", #TENSOR1, " ", TENSOR1->sizes(), " to be smaller size than ", #TENSOR3, " ", TENSOR3->sizes(), " and to be smaller than ", #TENSOR2, " ", TENSOR2->sizes(), " apart from dimension ", DIMENSION); \
} \
}

Expand Down
4 changes: 2 additions & 2 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor

elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);

TH_TENSOR_DIM_APPLY3(scalar_t, tensor, scalar_t, src, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim,
TH_TENSOR_DIM_APPLY3_SIZE_SCATTER,
for (i = 0; i < elems_per_row; ++i)
{
Expand Down Expand Up @@ -702,7 +702,7 @@ void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTen

elems_per_row = THTensor_sizeLegacyNoScalars(index, dim);

TH_TENSOR_DIM_APPLY3(scalar_t, tensor, scalar_t, src, int64_t, index, dim,
TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim,
TH_TENSOR_DIM_APPLY3_SIZE_SCATTER,
for (i = 0; i < elems_per_row; ++i)
{
Expand Down
22 changes: 22 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8785,6 +8785,28 @@ def test_scatterAdd(self):
def test_scatterFill(self):
    # Delegates to the shared scatter test helper; note the helper takes the
    # test case explicitly as its first argument (legacy helper signature).
    # The final `True` flag presumably selects the scalar-fill variant of
    # `scatter_` — confirm against _test_scatter_base's definition.
    self._test_scatter_base(self, lambda t: t, 'scatter_', True)

def test_scatter_to_large_input(self):
    # Regression test for pytorch#25836: scatter_ must iterate over `index`,
    # which here is strictly smaller than both the destination and `src`.
    for dev in torch.testing.get_all_device_types():
        dst = torch.zeros(4, 4, device=dev)
        src = torch.ones(2, 2, device=dev)
        idx = torch.tensor([[1], [2]], device=dev, dtype=torch.long)
        dst.scatter_(0, idx, src)
        # Only column 0 of rows 1 and 2 is written (index shape is (2, 1)).
        expected = torch.tensor([[0, 0, 0, 0],
                                 [1, 0, 0, 0],
                                 [1, 0, 0, 0],
                                 [0, 0, 0, 0]], device=dev)
        self.assertEqual(dst, expected)

def test_scatter_add_to_large_input(self):
    # Regression test for pytorch#25836: scatter_add_ must iterate over
    # `index`, which here is strictly smaller than both the destination
    # and `src`.
    for dev in torch.testing.get_all_device_types():
        dst = torch.zeros(4, 4, device=dev)
        src = torch.ones(2, 2, device=dev)
        idx = torch.tensor([[1], [2]], device=dev, dtype=torch.long)
        dst.scatter_add_(0, idx, src)
        # Only column 0 of rows 1 and 2 is accumulated into (index shape
        # is (2, 1)); the destination starts at zero, so results equal src.
        expected = torch.tensor([[0, 0, 0, 0],
                                 [1, 0, 0, 0],
                                 [1, 0, 0, 0],
                                 [0, 0, 0, 0]], device=dev)
        self.assertEqual(dst, expected)

def test_scatter_bool(self):
for device in torch.testing.get_all_device_types():
x = torch.tensor([[True, True, True], [True, True, True]], device=device)
Expand Down

0 comments on commit 26f67e7

Please sign in to comment.