Skip to content

Commit

Permalink
Add API and op for take_along_axis (PaddlePaddle#38396)
Browse files Browse the repository at this point in the history
* add API and op for take_along_axis

* fix compile dependency problem and add example code and doc

* add unitest

* delete some code for CI coverage

* fix code style problem

* fix as review
  • Loading branch information
huangxu96 authored Dec 28, 2021
1 parent 6f1bb3d commit 3310f51
Show file tree
Hide file tree
Showing 11 changed files with 881 additions and 1 deletion.
8 changes: 7 additions & 1 deletion paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()

set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils)
if (WITH_ROCM)
hip_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif()

set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils gather_scatter_kernel)

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
Expand Down
148 changes: 148 additions & 0 deletions paddle/fluid/operators/gather_scatter_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_scatter_kernel.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class TensorAssign {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data = *src_data;
}
};
static TensorAssign tensor_assign;

class ReduceAdd {
public:
template <typename tensor_t>
void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data += *src_data;
}
};

static ReduceAdd reduce_add;

template <typename tensor_t, typename index_t = int64_t,
bool is_scatter_like = true>
struct cpu_gather_scatter_functor {
template <typename func_t>
void operator()(Tensor self, int dim, const Tensor& index, const Tensor& src,
const std::string& method_name, const func_t& reduce_op,
const platform::DeviceContext& ctx) {
if (index.numel() == 0) {
return;
}
auto* self_data = self.data<tensor_t>();
auto* index_data = index.data<index_t>();
auto* src_data = src.data<tensor_t>();
int64_t self_size = self.numel();
int64_t index_size = index.numel();
int64_t src_size = src.numel();
auto self_dims = self.dims();
auto index_dims = index.dims();
auto src_dims = src.dims();
if (self_size == 0 || src_size == 0 || index_size == 0) {
VLOG(3) << "zero size input found";
platform::errors::InvalidArgument(
"self_size, src_size, index_size cannot be 0");
return;
}
int select_dim_size = index_dims[dim];
// index matrix has different shape with self matrix or src matrix.
int replaced_select_dim_size =
is_scatter_like ? self_dims[dim] : src_dims[dim];
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t index_idx = 0;
int64_t self_idx, src_idx;

// N layer loop squeezed into 3 layers loop
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < select_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = index_data[index_idx];

/*
gather computation formula:
self[i][j][k] = src[index[i][j][k]][j][k] # if dim == 0
self[i][j][k] = src[i][index[i][j][k]][k] # if dim == 1
self[i][j][k] = src[i][j][index[i][j][k]] # if dim == 2
scatter computation formula:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
*/

// This index might out of bound of index matrix's index, so here
// multiply the replaced_select_dim_size.
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * replaced_select_dim_size;

self_idx = is_scatter_like ? replace_index : index_idx;
src_idx = is_scatter_like ? index_idx : replace_index;

reduce_op((tensor_t*)(self_data + self_idx),
(tensor_t*)(src_data + src_idx));
index_idx++;
}
}
}
}
};

template <typename tensor_t, typename index_t>
void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/false>()(
result, dim, index, self, "gather_out_cpu", tensor_assign, ctx);
}

template <typename tensor_t, typename index_t>
void cpu_scatter_assign_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_assign_cpu", tensor_assign, ctx);
}

template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
cpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_add_cpu", reduce_add, ctx);
}

Instantiate_Template_Function(cpu_gather_kernel)
Instantiate_Template_Function(cpu_scatter_add_kernel)

} // namespace operators
} // namespace paddle
157 changes: 157 additions & 0 deletions paddle/fluid/operators/gather_scatter_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class TensorAssign {
public:
template <typename tensor_t>
constexpr void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data = *src_data;
}
};
static TensorAssign tensor_assign;

class ReduceAdd {
public:
template <
typename tensor_t,
std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
platform::CudaAtomicAdd(self_data, *src_data);
}
template <typename tensor_t,
std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
__device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
*self_data += *src_data;
}
};
static ReduceAdd reduce_add;

template <typename tensor_t, typename index_t, typename func_t,
bool is_scatter_like = true>
__global__ void GatherScatterGPUKernel(
tensor_t* self_data, int dim, const index_t* index_data, tensor_t* src_data,
int64_t inner_dim_size, int select_dim_size, int replaced_select_dim_size,
int64_t outer_dim_size, int64_t numel, const func_t& reduce_op) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;
int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop
// squeezed from the N layers loop.
/* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
j = remind / outer_dim_size;
k = remind % outer_dim_size;
index_t index = index_data[tid];
/*
gather computation formula:
self[i][j][k] = src[index[i][j][k]][j][k] # if dim == 0
self[i][j][k] = src[i][index[i][j][k]][k] # if dim == 1
self[i][j][k] = src[i][j][index[i][j][k]] # if dim == 2
scatter computation formula:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
*/
// index matrix has different shape with self matrix or src matrix.
int64_t replace_index = k + index * outer_dim_size +
i * outer_dim_size * replaced_select_dim_size;
int64_t self_idx = is_scatter_like ? replace_index : tid;
int64_t src_idx = is_scatter_like ? tid : replace_index;
reduce_op((tensor_t*)(self_data + self_idx), (tensor_t*)(src_data + src_idx));
}

template <typename tensor_t, typename index_t = int64_t,
bool is_scatter_like = true>
struct gpu_gather_scatter_functor {
template <typename func_t>
void operator()(Tensor self, int dim, const Tensor& index, Tensor src,
const std::string& method_name, const func_t& reduce_op,
const platform::DeviceContext& ctx) {
if (index.numel() == 0) {
return;
}
auto* self_data = self.data<tensor_t>();
auto* index_data = index.data<index_t>();
auto* src_data = src.data<tensor_t>();
int64_t self_size = self.numel();
int64_t index_size = index.numel();
int64_t src_size = src.numel();
auto self_dims = self.dims();
auto index_dims = index.dims();
auto src_dims = src.dims();
if (self_size == 0 || src_size == 0 || index_size == 0) return;
int select_dim_size = index_dims[dim];
// index matrix has different shape with self matrix or src matrix.
int replaced_select_dim_size =
is_scatter_like ? self_dims[dim] : src_dims[dim];
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int64_t i = 0; i < index_dims.size(); ++i) {
inner_dim_size *= index_dims[i];
}

for (int i = dim + 1; i < index_dims.size(); i++) {
outer_dim_size *= index_dims[i];
}

int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
GatherScatterGPUKernel<tensor_t, index_t, func_t,
is_scatter_like><<<grid, block, 0, stream>>>(
self_data, dim, index_data, src_data, inner_dim_size, select_dim_size,
replaced_select_dim_size, outer_dim_size, index_size, reduce_op);
}
}; // struct gpu_gather_scatter_functor

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/false>()(
result, dim, index, self, "gather_out_gpu", tensor_assign, ctx);
return;
}

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx) {
gpu_gather_scatter_functor<tensor_t, index_t,
/*is_scatter_like=*/true>()(
self, dim, index, src, "scatter_add_gpu", reduce_add, ctx);
}

namespace plat = paddle::platform;
Instantiate_Template_Function(gpu_gather_kernel)
Instantiate_Template_Function(gpu_scatter_add_kernel)

} // namespace operators
} // namespace paddle
57 changes: 57 additions & 0 deletions paddle/fluid/operators/gather_scatter_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/tensor.h"

#pragma once

namespace paddle {
namespace operators {

#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, platform::float16) \
Instantiate_Template_Function_index_t(func, unsigned char)

#define Instantiate_Template_Function_index_t(func, tensor_t) \
template void func<tensor_t, int>(Tensor input, int dim, \
const Tensor& index, Tensor result, \
const platform::DeviceContext& ctx); \
template void func<tensor_t, int64_t>(Tensor input, int dim, \
const Tensor& index, Tensor result, \
const platform::DeviceContext& ctx);

using Tensor = framework::Tensor;

template <typename tensor_t, typename index_t>
void cpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void cpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_gather_kernel(Tensor self, int dim, const Tensor& index, Tensor result,
const platform::DeviceContext& ctx);

template <typename tensor_t, typename index_t>
void gpu_scatter_add_kernel(Tensor self, int dim, const Tensor& index,
Tensor src, const platform::DeviceContext& ctx);

} // namespace operators
} // namespace paddle
Loading

0 comments on commit 3310f51

Please sign in to comment.