forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
porting scatter_add to ATen (CPU) (pytorch#31662)
Summary: Fixes [https://github.com/pytorch/pytorch/issues/24758](https://github.com/pytorch/pytorch/issues/24758). Pull Request resolved: pytorch#31662 Differential Revision: D19440824 Pulled By: ngimel fbshipit-source-id: b13443cfcc8bcb9ec21f1cddb5c6fbc0ef4bb0f2
- Loading branch information
1 parent
5342968
commit 61ee8c9
Showing
9 changed files
with
209 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#pragma once

#include <vector>
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {

// NOTE(review): an anonymous namespace in a header gives this declaration
// internal linkage separately in every translation unit that includes it;
// only the TU that also defines scatter_shape_check can call it. Confirm
// this header is meant for a single including TU, or move the declaration
// out of the anonymous namespace together with its definition.
namespace {

// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= src.size(d) for all d
// 2. index.size(d) <= self.size(d) for all d != dim
void scatter_shape_check(
  const Tensor& self, int64_t dim,
  const Tensor& index, const Tensor& src
);

} // anonymous namespace

}} // namespace at::native
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
#include <ATen/native/ScatterGatherShapeChecks.h> | ||
#include <ATen/native/TensorIterator.h> | ||
#include <ATen/Parallel.h> | ||
|
||
namespace at { namespace native { | ||
|
||
namespace { | ||
|
||
// Number of dimensions to iterate over, treating a 0-dim (scalar)
// tensor as if it had a single dimension.
static inline int64_t ensure_nonempty_dim(int64_t dim) {
  return dim < 1 ? 1 : dim;
}
|
||
static inline int64_t ensure_nonempty_size(const Tensor& t, int64_t dim) { | ||
return t.dim() == 0 ? 1 : t.size(dim); | ||
} | ||
|
||
static inline int64_t ensure_nonempty_stride(const Tensor& t, int64_t dim) { | ||
return t.dim() == 0 ? 1 : t.stride(dim); | ||
} | ||
|
||
using IdxVec = std::vector<int64_t>;

// Give an empty (0-dim) sizes/strides vector a single entry of 1,
// matching the ensure_nonempty_* scalar helpers above.
static inline IdxVec ensure_nonempty_vec(IdxVec vec) {
  if (vec.empty()) {
    vec.assign(1, 1);
  }
  return vec;
}
|
||
// Used for `scatter` and `scatter_add` | ||
// Tests: | ||
// 1. index.size(d) <= src.size(d) for all d | ||
// 2. index.size(d) <= self.size(d) for all d != dim | ||
void scatter_shape_check( | ||
const Tensor& self, int64_t dim, | ||
const Tensor& index, const Tensor& src | ||
) { | ||
bool is_wrong_shape = false; | ||
int64_t self_dims = ensure_nonempty_dim(self.dim()); | ||
for (int64_t d = 0; d < self_dims; ++d) { | ||
int64_t index_d_size = ensure_nonempty_size(index, d); | ||
if (index_d_size > ensure_nonempty_size(src, d)) { | ||
is_wrong_shape = true; | ||
break; | ||
} | ||
if (d != dim) { | ||
if (index_d_size > ensure_nonempty_size(self, d)) { | ||
is_wrong_shape = true; | ||
break; | ||
} | ||
} | ||
} | ||
TORCH_CHECK(!is_wrong_shape, | ||
"Expected index [", index.sizes(), "]", | ||
" to be smaller size than src [", src.sizes(), "]", | ||
" and to be smaller than self [", self.sizes(), "]", | ||
" apart from dimension ", dim | ||
); | ||
} | ||
|
||
// View `src` with shape `replacement_shape` and a zero stride along `dim`,
// so that stepping across `dim` in the view revisits the same memory;
// the kernel functor walks `dim` itself using the original stride.
static Tensor restride_dim(
  const Tensor& src, int64_t dim,
  IntArrayRef replacement_shape
) {
  IdxVec new_strides = ensure_nonempty_vec(src.strides().vec());
  new_strides[dim] = 0;
  return src.as_strided(replacement_shape, new_strides);
}
|
||
// Shared driver for scatter/gather-style CPU kernels.
//
// Builds a TensorIterator over (self, src, index) restrided so that the
// iterated shape equals index's shape with size 1 / stride 0 along `dim`;
// the functor `f` is then responsible for walking `dim` itself using the
// *_dim_stride values passed to it.
//
// self        - destination tensor (registered as the iterator output)
// dim         - dimension traversed inside `f`
// index, src  - index and source tensors
// method_name - name reported by AT_DISPATCH on an unsupported dtype
// f           - called as f(self_ptr, self_dim_stride,
//                           index_ptr, index_dim_stride,
//                           src_ptr, src_dim_stride) once per iterated element
// serial_exec - when true, run single-threaded via serial_for_each
template <typename func_t>
void cpu_scatter_gather_base_kernel(
  Tensor& self, int64_t dim,
  const Tensor& index, const Tensor& src,
  const std::string& method_name,
  const func_t& f,
  bool serial_exec = true
) {
  auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
  auto index_strides = ensure_nonempty_vec(index.strides().vec());

  // `dim` is traversed in a kernel function `f`,
  // that is why index.stride(dim) = 0 and index.size(dim) = 1.
  // Also, index.size(dim) = 1 makes sure that TensorIterator.DimCounter
  // has the following form : (i_1,..., i_{dim-1}, 0, i_{dim+1},...,i_n).
  index_sizes[dim] = 1;
  index_strides[dim] = 0;

  // set self.shape = src.shape = index.shape,
  // this defines the number of elements to iterate over,
  // and set self.stride(dim) = src.stride(dim) = 0,
  // because `dim` is traversed in a kernel function `f`.
  auto self_restrided = restride_dim(self, dim, index_sizes);
  auto index_restrided = index.as_strided(index_sizes, index_strides);
  auto src_restrided = restride_dim(src, dim, index_sizes);

  auto iter = TensorIterator();
  // Keep each operand's own dtype (index stays int64) and the output shape.
  iter.dont_compute_common_dtype();
  iter.dont_resize_outputs();
  iter.add_output(self_restrided);
  iter.add_input(src_restrided, src.device(), src.scalar_type());
  iter.add_input(index_restrided);
  iter.build();

  // Original (pre-restride) strides along `dim`, handed to `f` so it can
  // step through the traversed dimension itself.
  auto self_dim_stride = ensure_nonempty_stride(self, dim);
  auto index_dim_stride = ensure_nonempty_stride(index, dim);
  auto src_dim_stride = ensure_nonempty_stride(src, dim);

  AT_DISPATCH_ALL_TYPES_AND2(
    ScalarType::Bool, ScalarType::Half, iter.dtype(),
    method_name, [&] {
      auto loop = [&](char** data, const int64_t* strides, int64_t n) {
        // Operand order follows iterator construction above:
        // data[0] = output (self), data[1] = src, data[2] = index.
        auto* self_data_bytes = data[0];
        const auto* index_data_bytes = data[2];
        const auto* src_data_bytes = data[1];

        for (int64_t i = 0; i < n; ++i) {
          f(
            (scalar_t*)self_data_bytes, self_dim_stride,
            (int64_t*)index_data_bytes, index_dim_stride,
            (scalar_t*)src_data_bytes, src_dim_stride
          );

          // Advance each operand by its iterator-computed byte stride.
          self_data_bytes += strides[0];
          index_data_bytes += strides[2];
          src_data_bytes += strides[1];
        }
      };
      if (serial_exec) {
        iter.serial_for_each(loop, {0, iter.numel()});
      } else {
        iter.for_each(loop);
      }
    }
  );
}
|
||
// CPU kernel for scatter_add_: accumulates src into self at positions
// given by index along dimension `dim`. Runs serially (serial_exec=true),
// so repeated indices accumulate deterministically.
void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  // Empty index: nothing to scatter.
  if (index.numel() == 0) {
    return;
  }

  dim = maybe_wrap_dim(dim, self.dim());

  scatter_shape_check(self, dim, index, src);

  const int64_t index_dim_size = ensure_nonempty_size(index, dim);
  const int64_t self_dim_size = ensure_nonempty_size(self, dim);

  cpu_scatter_gather_base_kernel(
    self, dim, index, src,
    "scatter_add_", [&] (
      auto* self_data, auto self_dim_stride,
      const auto* index_data, auto index_dim_stride,
      const auto* src_data, auto src_dim_stride
    ) {
      // Walk the scatter dimension; the base kernel handles all others.
      for (int64_t k = 0; k < index_dim_size; ++k) {
        const int64_t idx_dim = index_data[k * index_dim_stride];
        TORCH_CHECK(idx_dim >= 0 && idx_dim < self_dim_size,
          "index ", idx_dim,
          " is out of bounds for dimension ", dim,
          " with size ", self_dim_size);
        self_data[idx_dim * self_dim_stride] += src_data[k * src_dim_stride];
      }
    }, /*serial_exec=*/true
  );
}
|
||
} // anonymous namespace
|
||
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cpu_kernel); | ||
|
||
}} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters