enable deterministic path for index_put with accumulate=False on CPU and CUDA (pytorch#57839)

Summary:
Pull Request resolved: pytorch#57839

We reuse `index_put_accum_kernel`, renaming it to `index_put_with_sort_kernel`, and add a bool `accumulate` argument to `indexing_backward_kernel`. On CPU, `cpu_index_kernel` now also runs serially for the non-accumulate case when deterministic algorithms are enabled, so the `alertNotDeterministic` call for `index_put_` with accumulate=False is removed.
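
For illustration, a minimal usage sketch of what this enables (tensor sizes and values are made up; assumes a CUDA build, and the same call works on CPU):

    import torch

    # Opt in to deterministic implementations globally.
    torch.use_deterministic_algorithms(True)

    x = torch.zeros(5, device="cuda")                    # or device="cpu"
    idx = (torch.tensor([0, 0, 3], device="cuda"),)      # note the duplicate index 0
    vals = torch.tensor([1.0, 2.0, 7.0], device="cuda")

    # Before this change the call below raised a nondeterministic-operation alert
    # in deterministic mode; it now runs on the sort-based CUDA path (serial path
    # on CPU) and matches serial, in-order application, so x[0] ends up as 2.0.
    x.index_put_(idx, vals, accumulate=False)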

Test Plan:
buck test mode/opt //caffe2/test:torch -- test_index_put_non_accumulate_deterministic

    ✓ Pass: caffe2/test:torch - test_index_put_non_accumulate_deterministic_cpu (test_torch.TestTorchDeviceTypeCPU) (5.120)
Summary
  Pass: 1
  Skip: 1
    ↻ caffe2/test:torch - test_index_put_non_accumulate_deterministic_meta (test_torch.TestTorchDeviceTypeMETA)
  ListingSuccess: 1

buck test mode/opt //caffe2/test:torch_cuda -- test_index_put_non_accumulate_deterministic

    ✓ ListingSuccess: caffe2/test:torch_cuda - main (6.397)
    ✓ Pass: caffe2/test:torch_cuda - test_index_put_non_accumulate_deterministic_cuda (test_torch.TestTorchDeviceTypeCUDA) (26.030)
    ✓ Pass: caffe2/test:torch_cuda - main (26.030)
Summary
  Pass: 2
  ListingSuccess: 1

Reviewed By: ngimel

Differential Revision: D28290699

fbshipit-source-id: df8bbe7af2e72017566161b05b85737fda4ceb3f
yuguo68 authored and facebook-github-bot committed May 12, 2021
1 parent d623fb7 commit a07a019
Showing 7 changed files with 52 additions and 45 deletions.
13 changes: 4 additions & 9 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -79,15 +79,15 @@ DEFINE_DISPATCH(index_copy_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(index_put_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(index_put_accum_stub);
DEFINE_DISPATCH(index_put_with_sort_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(put_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(take_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(masked_fill_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);
REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub, index_put_with_sort_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(masked_select_serial_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -402,10 +402,10 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>
}
}

if (accumulate && self.device().type() == DeviceType::CUDA) {
if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) {
TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ",
value.device(), " for value tensor");
index_put_accum_stub(self.device().type(), self, indices, value, unsafe);
index_put_with_sort_stub(self.device().type(), self, indices, value, accumulate, unsafe);
return self;
}

@@ -456,11 +456,6 @@ Tensor take(const Tensor& self, const Tensor& index) {
}

Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
if (!accumulate) {
// See note [Writing Nondeterministic Operations]
// Nondeterministic when index contains duplicate entries
at::globalContext().alertNotDeterministic("index_put_ with accumulate=False");
}
return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
}

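To make the new dispatch condition above easier to follow, here is a small control-flow sketch in Python (not the ATen C++ itself); the stub names are the ones defined in this file:

    def pick_index_put_kernel(device_type, accumulate, deterministic_enabled):
        # Mirrors the condition in _index_put_impl_: CUDA takes the sort-based,
        # deterministic kernel whenever accumulating or when deterministic
        # algorithms are enabled; everything else falls through to the
        # TensorIterator-based kernel (which the CPU serializes when needed).
        if device_type == "cuda" and (accumulate or deterministic_enabled):
            return "index_put_with_sort_stub"
        return "index_put_stub"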
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -17,7 +17,7 @@ using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRe
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using index_put_accum_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool unsafe);
using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
using put_fn = void(*)(TensorIterator & iter, const Tensor& self, const bool accumulate);
using take_fn = void(*)(TensorIterator & iter, const Tensor& input);
@@ -37,7 +37,7 @@ DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);
DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(put_fn, put_stub);
DECLARE_DISPATCH(take_fn, take_stub);
DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -231,11 +231,11 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
// NOTE: duplicate indices are only supported if accumulate is true.
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
iter.dtype(), "index_put", [&] {
// See Note [Enabling Deterministic Operations]
// Parallel cpu_index_kernel with accumulation is nondeterministic, so we
// must enable serial execution if deterministic algorithms are enabled.
const bool is_deterministic = at::globalContext().deterministicAlgorithms();
if (accumulate) {
// See Note [Enabling Deterministic Operations]
// Parallel cpu_index_kernel with accumulation is nondeterministic, so we
// must enable serial execution if deterministic algorithms are enabled.
bool is_deterministic = at::globalContext().deterministicAlgorithms();
bool use_parallel_for = (!is_deterministic) && (
(iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
if (use_parallel_for && iter.dtype() == ScalarType::Float) {
@@ -252,7 +252,7 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
} else {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset) = *(scalar_t*)src;
});
}, /*serial_execution=*/is_deterministic);
}
});
}
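The reason serial execution is now forced for the non-accumulate case as well: if duplicate indices land in different parallel chunks, the surviving value depends on which chunk happens to write last. A conceptual Python simulation of that race (an illustration only, not how cpu_index_kernel is implemented):

    import random

    def simulated_parallel_index_put(dst, indices, values, num_chunks=2):
        # Split the updates into chunks, as a parallel_for would, then apply the
        # chunks in an arbitrary order to model unspecified thread scheduling.
        pairs = list(zip(indices, values))
        size = (len(pairs) + num_chunks - 1) // num_chunks
        chunks = [pairs[i:i + size] for i in range(0, len(pairs), size)]
        random.shuffle(chunks)  # the source of nondeterminism
        for chunk in chunks:
            for i, v in chunk:
                dst[i] = v      # a duplicate index shared across chunks races here
        return dst

    # dst[0] can come out as 1.0 or 3.0 depending on which chunk ran "last".
    print(simulated_parallel_index_put([0.0] * 4, [0, 1, 0, 1], [1.0, 2.0, 3.0, 4.0]))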
28 changes: 20 additions & 8 deletions aten/src/ATen/native/cuda/Indexing.cu
@@ -29,7 +29,7 @@ namespace {
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
int64_t* sorted_indices, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
//numel is total number of flattened indices, not expanded to dimensions that are not indexed.
//stride is the cumulative size of the not-indexed last dimensions
//stride_before is the stride of the dimension immediately preceding first indexed dimension
@@ -55,6 +55,11 @@ __global__ void indexing_backward_kernel(
&& (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
do {
int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
// if not accumulate, we only keep the last duplicate index so skip those before it
if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
idx++;
continue;
}
const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
const accscalar_t scale = (accscalar_t)1.0;
@@ -68,13 +73,19 @@
int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
if (feature_dim < stride) {
gradient[ii] = static_cast<accscalar_t>(grad_output[grad_row + feature_dim]);
weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
if (accumulate) {
weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
}
}
}

#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
weight[ii] += gradient[ii] * scale;
if (accumulate) {
weight[ii] += gradient[ii] * scale;
} else {
weight[ii] = gradient[ii] * scale;
}
}

#pragma unroll
@@ -183,7 +194,7 @@ static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t
}


void index_put_accum_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);

namespace {

@@ -195,7 +206,7 @@ int64_t largestIndex(const Tensor &self) {
return result;
}

void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
if (indices.size() > (size_t)self.dim()) {
TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
}
@@ -224,7 +235,7 @@ void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>
// this bug is fixed in CUDA 11.3
#if defined(CUDA_VERSION) && CUDA_VERSION < 11030
if (num_indices < 50000) {
index_put_accum_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
} else
#endif
{
@@ -257,7 +268,8 @@ void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>
num_indices,
sliceSize,
strideBefore,
nElemBefore);
nElemBefore,
accumulate);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});

@@ -266,7 +278,7 @@ void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>
}
}

REGISTER_CUDA_DISPATCH(index_put_accum_stub, &index_put_accum_kernel);
REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
} //anonymous


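As the updated test below checks, with accumulate=False the deterministic sort-based path matches a serial, in-order application of the updates, so the value from the last occurrence of a duplicated index wins (with accumulate=True all duplicates are summed). A short Python reference of that semantics (a sketch mirroring the test, not the kernel):

    def index_put_reference(dst, indices, values, accumulate=False):
        # Serial, in-order application: the behavior the deterministic path matches.
        out = list(dst)
        for i, v in zip(indices, values):
            if accumulate:
                out[i] += v
            else:
                out[i] = v   # a later duplicate overwrites an earlier one
        return out

    # index_put_reference([0, 0, 0], [0, 2, 0], [1.0, 7.0, 2.0]) -> [2.0, 0, 7.0]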
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
@@ -7,7 +7,7 @@

namespace at { namespace native {

void index_put_accum_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
sorted_indices.copy_(linearIndex);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
38 changes: 19 additions & 19 deletions test/test_torch.py
@@ -3966,25 +3966,6 @@ def forward_func(slf, device):
test_func(torch.Tensor.scatter_add)
test_func(torch.scatter_add)

# Ensures that index_put throws nondeterministic alerts in the correct cases
@onlyOnCPUAndCUDA
def test_nondeterministic_alert_index_put(self, device):
def test_func(op_call):
a = torch.randn(10, device=device)
indices = (torch.tensor([0, 0], device=device), )
values = torch.tensor([0, 1], device=device)

@expectedAlertNondeterministic('index_put_ with accumulate=False')
def forward_func(slf, device):
op_call(a, indices, values, accumulate=False)

forward_func(self, device)

test_func(torch.index_put)
test_func(torch.Tensor.index_put)
test_func(torch.index_put_)
test_func(torch.Tensor.index_put_)

@onlyOnCPUAndCUDA
def test_nondeterministic_alert_put(self, device):
def test_func(op_call):
@@ -5306,6 +5287,25 @@ def test_index_add_deterministic(self, device):
y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)

@onlyOnCPUAndCUDA
def test_index_put_non_accumulate_deterministic(self, device) -> None:
with DeterministicGuard(True):
for i in range(3):
m = random.randint(10, 20)
elems = random.randint(20000, 30000)
values = torch.rand(elems, device=device)
indices = torch.randint(m, (elems,), device=device)
input = torch.rand(m, device=device)
output = input.index_put((indices,), values, accumulate=False)

input_list = input.tolist()
indices_list = indices.tolist()
values_list = values.tolist()
for i, v in zip(indices_list, values_list):
input_list[i] = v

self.assertEqual(output, input_list)

@dtypes(*torch.testing.get_all_dtypes())
def test_index_fill(self, device, dtype):
x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -375,6 +375,7 @@ def use_deterministic_algorithms(mode):
* :func:`torch.bmm` when called on sparse-dense CUDA tensors
* :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
and the index is a list of tensors
* :func:`torch.Tensor.index_put` with ``accumulate=False``
* :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
tensor
* :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
@@ -415,7 +416,6 @@ def use_deterministic_algorithms(mode):
``mode='max'``
* :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
* :func:`torch.Tensor.index_copy` when called on a CUDA tensor
* :func:`torch.Tensor.index_put_` when ``accumulate=False``
* :func:`torch.Tensor.put_` when ``accumulate=False``
* :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
* :func:`torch.histc` when called on a CUDA tensor
