add kernel primitive api #3890

Merged
merged 75 commits into from
Oct 13, 2021
Commits
fafbb84
add kernel primitive api
AnnaTrainingG Sep 17, 2021
997a971
Merge branch 'develop' of https://github.com/PaddlePaddle/docs into P…
AnnaTrainingG Sep 17, 2021
7f0812f
modfied index_cn in guides
AnnaTrainingG Sep 18, 2021
c83cb7b
update
AnnaTrainingG Sep 18, 2021
210ab74
update
AnnaTrainingG Sep 18, 2021
319ba2f
update
AnnaTrainingG Sep 18, 2021
c95af9d
update
AnnaTrainingG Sep 18, 2021
3a6b52a
update
AnnaTrainingG Sep 18, 2021
6012225
update
AnnaTrainingG Sep 18, 2021
21bc6f3
update
AnnaTrainingG Sep 18, 2021
98a2eb7
update
AnnaTrainingG Sep 18, 2021
12cf54f
yes
AnnaTrainingG Sep 22, 2021
66621b9
update
AnnaTrainingG Sep 22, 2021
cbcf3b8
update
AnnaTrainingG Sep 22, 2021
b54eb97
update
AnnaTrainingG Sep 22, 2021
84b6984
add case
AnnaTrainingG Sep 22, 2021
cce555f
update
AnnaTrainingG Sep 22, 2021
8e8ab02
all in io
AnnaTrainingG Sep 22, 2021
da7593c
temp
AnnaTrainingG Sep 22, 2021
e18ec8f
temp update
AnnaTrainingG Sep 22, 2021
d674dd7
temp update
AnnaTrainingG Sep 22, 2021
c21109b
update temp
AnnaTrainingG Sep 22, 2021
cdf48c1
update
AnnaTrainingG Sep 22, 2021
f6d915c
add index_en
AnnaTrainingG Sep 23, 2021
4b4d40d
add en
AnnaTrainingG Sep 23, 2021
01943ef
update en
AnnaTrainingG Sep 23, 2021
4b37f31
update
AnnaTrainingG Sep 23, 2021
57d114d
update
AnnaTrainingG Sep 24, 2021
f7f36e9
update
AnnaTrainingG Sep 24, 2021
b48082c
update
AnnaTrainingG Sep 24, 2021
e2d1e5f
update
AnnaTrainingG Sep 24, 2021
cc30464
upate
AnnaTrainingG Sep 24, 2021
fa50b68
Update index_en.rst
AnnaTrainingG Sep 24, 2021
4d62970
update
AnnaTrainingG Sep 24, 2021
3924dd2
Merge branch 'Primitive_API_31094' of https://github.com/niuliling123…
AnnaTrainingG Sep 24, 2021
84986c6
update
AnnaTrainingG Sep 24, 2021
19d0f5b
update
AnnaTrainingG Sep 24, 2021
cbccb13
update
AnnaTrainingG Sep 24, 2021
caab6d2
update
AnnaTrainingG Sep 24, 2021
fec8e48
add url
AnnaTrainingG Sep 26, 2021
fd8552b
update from ELe to Ele
AnnaTrainingG Sep 27, 2021
2f82c40
update
AnnaTrainingG Sep 27, 2021
4e23256
update Block
AnnaTrainingG Sep 27, 2021
8ca099b
update
AnnaTrainingG Sep 28, 2021
94123a6
add functor
AnnaTrainingG Sep 29, 2021
d6c2ec2
update
AnnaTrainingG Sep 29, 2021
8fbf2cb
update
AnnaTrainingG Sep 29, 2021
50f114b
add images
AnnaTrainingG Sep 29, 2021
6a0433c
update
AnnaTrainingG Sep 29, 2021
b3611e2
temp
AnnaTrainingG Sep 29, 2021
315ab1f
add en
AnnaTrainingG Sep 29, 2021
22ca7e0
update
AnnaTrainingG Sep 29, 2021
898b611
add static_cast
AnnaTrainingG Sep 29, 2021
f782de7
add static_cast
AnnaTrainingG Sep 29, 2021
a185ff0
update
AnnaTrainingG Sep 30, 2021
6f806d9
add functor_en
AnnaTrainingG Sep 30, 2021
32b04d6
update image
AnnaTrainingG Sep 30, 2021
2c788b5
update
AnnaTrainingG Sep 30, 2021
ffbc1a5
update
AnnaTrainingG Sep 30, 2021
16f2cff
add example_reduce.png
AnnaTrainingG Sep 30, 2021
757594e
add example_reduce.png
AnnaTrainingG Sep 30, 2021
0c12dcb
update and add example_add png
AnnaTrainingG Oct 8, 2021
a531027
update
AnnaTrainingG Oct 8, 2021
f65400c
update en
AnnaTrainingG Oct 9, 2021
e6a3d6d
update index and functor
AnnaTrainingG Oct 9, 2021
253bf26
update
AnnaTrainingG Oct 11, 2021
bc68805
update all
AnnaTrainingG Oct 11, 2021
b438f08
update paddlepaddle
AnnaTrainingG Oct 11, 2021
3b9322c
add url for functor
AnnaTrainingG Oct 11, 2021
df442c7
update for merge
AnnaTrainingG Oct 12, 2021
187147f
update functor_en url
AnnaTrainingG Oct 12, 2021
4dfed57
add init in io
AnnaTrainingG Oct 12, 2021
a23984a
update stride_nx in image
AnnaTrainingG Oct 12, 2021
978688d
update ;
AnnaTrainingG Oct 12, 2021
14856dd
update image and notes
AnnaTrainingG Oct 12, 2021
2 changes: 2 additions & 0 deletions docs/guides/07_new_op/index_cn.rst
@@ -16,6 +16,7 @@

- `Custom Python Operators <./new_python_op_cn.html>`_

- `Kernel Primitives API <./kernel_primitive_api/index_cn.html>`_ : Introduces the block-level CUDA functions that PaddlePaddle provides to speed up operator development.

.. toctree::
:hidden:
@@ -24,3 +25,4 @@
op_notes_cn.md
new_custom_op_cn.md
new_python_op_cn.md
kernel_primitive_api/index_cn.rst
3 changes: 3 additions & 0 deletions docs/guides/07_new_op/index_en.rst
@@ -8,8 +8,11 @@ This section will guide you how to add an operator, and it also includes some ne

- `op notes <op_notes_en.html>`_ :notes on developing new operators

- `Kernel Primitives API <./kernel_primitive_api/index_en.html>`_ : Introduces the block-level CUDA functions provided by PaddlePaddle to speed up operator development.

.. toctree::
:hidden:

new_op_en.md
op_notes_en.md
kernel_primitive_api/index_en.rst
71 changes: 71 additions & 0 deletions docs/guides/07_new_op/kernel_primitive_api/add_example_cn.md
@@ -0,0 +1,71 @@
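# Example - ElementwiseAdd
## Description
+ Adds two inputs of the same shape element by element. The inputs have type InT, the output has type OutT, and the computation is defined by OpFunc.

### OpFunc Definition

OpFunc defines the calculation rule applied to the data. AddFunctor is defined as follows:

```
template <typename InT>
struct AddFunctor {
  HOSTDEVICE InT operator()(const InT &a, const InT &b) const { return (a + b); }
};
```
### Kernel Implementation

Each thread reads VecSize contiguous elements. Depending on how the number of remaining elements, num, compares with VecSize * blockDim.x, processing falls into two cases. In the first case, VecSize * blockDim.x > num: the tile needs boundary handling, so IsBoundary is set to true to avoid out-of-bounds memory access. In the second case no boundary handling is needed, so IsBoundary is set to false. Starting from the data pointer of the current block, the data is read from global memory into registers, the addition is performed, and the result is written back to global memory. Note that the Init function initializes the registers arg0 and arg1 so that neither holds 0 when arg0 or arg1 is used as a denominator. OpFunc performs the addition; to multiply the two inputs instead, only the functor needs to change, and the kernel code can be reused as is, which speeds up development.

The data processing flow is as follows:
![ElementwiseAdd](./images/example_add.png)

### Kernel Code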

```
#include "kernel_primitives/kernel_primitives.h"

// Processes one tile of the current block; num is the number of elements
// remaining from this block's offset.
template<int VecSize, typename InT, typename OutT, typename OpFunc, bool IsBoundary>
__device__ void ElementwiseAddImpl(InT *in0, InT *in1, OutT *out, OpFunc func, int num) {
  InT arg0[VecSize];
  InT arg1[VecSize];
  OutT result[VecSize];

  // init arg0 and arg1 (the registers hold InT, so the initial value is cast to InT)
  Init<InT, VecSize>(arg0, static_cast<InT>(1.0f));
  Init<InT, VecSize>(arg1, static_cast<InT>(1.0f));

  // read data from global memory
  ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(arg0, in0, num);
  ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(arg1, in1, num);

  // compute result[i] = arg0[i] + arg1[i]
  ElementwiseBinary<InT, OutT, VecSize, 1, 1, OpFunc>(result, arg0, arg1, func);

  // write data
  WriteData<OutT, VecSize, 1, 1, IsBoundary>(out, result, num);
}

template<int VecSize, typename InT, typename OutT>
__global__ void ElementwiseAdd(InT *in0, InT *in1, OutT *out, int size) {
  // get the data offset of this block
  int data_offset = VecSize * blockIdx.x * blockDim.x;

  // get the stride between consecutive tiles handled by this block
  int stride = gridDim.x * blockDim.x * VecSize;

  for (int offset = data_offset; offset < size; offset += stride) {
    if (offset + blockDim.x * VecSize < size) {  // full tile, set IsBoundary = false
      ElementwiseAddImpl<VecSize, InT, OutT, AddFunctor<InT>, false>(in0 + offset, in1 + offset, out + offset, AddFunctor<InT>(), size - offset);
    } else {  // at most blockDim.x * VecSize elements are left, IsBoundary must be true
      ElementwiseAddImpl<VecSize, InT, OutT, AddFunctor<InT>, true>(in0 + offset, in1 + offset, out + offset, AddFunctor<InT>(), size - offset);
    }
  }
}

```
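The point that only the functor needs to change can be illustrated on the host. The following is a minimal C++ sketch, not the Paddle implementation: `ElementwiseBinaryHost` is a hypothetical stand-in for the per-thread loop, `MulFunctor` is a hypothetical functor that does not appear in this document, and `HOSTDEVICE` is dropped so the code compiles without CUDA.

```cpp
#include <cassert>

// Same shape as the AddFunctor above, without HOSTDEVICE (host-only sketch).
template <typename InT>
struct AddFunctor {
  InT operator()(const InT &a, const InT &b) const { return a + b; }
};

// Hypothetical functor, not part of the document: element-wise product.
template <typename InT>
struct MulFunctor {
  InT operator()(const InT &a, const InT &b) const { return a * b; }
};

// Hypothetical host-side stand-in for the per-thread loop:
// out[i] = func(in0[i], in1[i]) for the VecSize elements held by one thread.
template <int VecSize, typename InT, typename OutT, typename OpFunc>
void ElementwiseBinaryHost(OutT *out, const InT *in0, const InT *in1, OpFunc func) {
  static_assert(VecSize > 0, "VecSize must be positive");
  assert(out != nullptr && in0 != nullptr && in1 != nullptr);
  for (int i = 0; i < VecSize; ++i) {
    out[i] = static_cast<OutT>(func(in0[i], in1[i]));
  }
}
```

Calling `ElementwiseBinaryHost<4, float, float>(out, a, b, MulFunctor<float>())` instead of passing `AddFunctor<float>()` switches the operation without touching the loop, which mirrors how the kernel above could be reused.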
70 changes: 70 additions & 0 deletions docs/guides/07_new_op/kernel_primitive_api/add_example_en.md
@@ -0,0 +1,70 @@
# ElementwiseAdd
## Description
+ Adds two inputs of the same shape element by element. The inputs have type InT, the output has type OutT, and the computation is defined by OpFunc.

### OpFunc Definition

OpFunc defines the calculation rule. AddFunctor is defined as follows:

```
template <typename InT>
struct AddFunctor {
  HOSTDEVICE InT operator()(const InT &a, const InT &b) const { return (a + b); }
};
```
### Kernel Description
Each thread reads VecSize contiguous elements. Depending on how the number of remaining elements, num, compares with VecSize * blockDim.x, processing falls into two cases. In the first case, VecSize * blockDim.x > num: the tile needs boundary handling, so IsBoundary is set to true to avoid out-of-bounds memory access. In the second case no boundary handling is needed, so IsBoundary is set to false. Starting from the data pointer of the current block, the data is read from global memory into registers, the addition is performed, and the result is written back to global memory. Note that the Init function initializes the registers arg0 and arg1 so that neither holds 0 when arg0 or arg1 is used as a denominator. OpFunc performs the addition; to multiply the two inputs instead, only the functor needs to change, and the kernel code can be reused as is, which improves development efficiency. </br>
The data processing flow of ElementwiseAdd is as follows:</br>
![ElementwiseAdd](./images/example_add.png)

### Code

```

#include "kernel_primitives/kernel_primitives.h"

// Processes one tile of the current block; num is the number of elements
// remaining from this block's offset.
template<int VecSize, typename InT, typename OutT, typename OpFunc, bool IsBoundary>
__device__ void ElementwiseAddImpl(InT *in0, InT *in1, OutT *out, OpFunc func, int num) {
  InT arg0[VecSize];
  InT arg1[VecSize];
  OutT result[VecSize];

  // init arg0 and arg1 (the registers hold InT, so the initial value is cast to InT)
  Init<InT, VecSize>(arg0, static_cast<InT>(1.0f));
  Init<InT, VecSize>(arg1, static_cast<InT>(1.0f));

  // read data from global memory
  ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(arg0, in0, num);
  ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(arg1, in1, num);

  // compute result[i] = arg0[i] + arg1[i]
  ElementwiseBinary<InT, OutT, VecSize, 1, 1, OpFunc>(result, arg0, arg1, func);

  // write data
  WriteData<OutT, VecSize, 1, 1, IsBoundary>(out, result, num);
}

template<int VecSize, typename InT, typename OutT>
__global__ void ElementwiseAdd(InT *in0, InT *in1, OutT *out, int size) {
  // get the data offset of this block
  int data_offset = VecSize * blockIdx.x * blockDim.x;

  // get the stride between consecutive tiles handled by this block
  int stride = gridDim.x * blockDim.x * VecSize;

  for (int offset = data_offset; offset < size; offset += stride) {
    if (offset + blockDim.x * VecSize < size) {  // full tile, set IsBoundary = false
      ElementwiseAddImpl<VecSize, InT, OutT, AddFunctor<InT>, false>(in0 + offset, in1 + offset, out + offset, AddFunctor<InT>(), size - offset);
    } else {  // at most blockDim.x * VecSize elements are left, IsBoundary must be true
      ElementwiseAddImpl<VecSize, InT, OutT, AddFunctor<InT>, true>(in0 + offset, in1 + offset, out + offset, AddFunctor<InT>(), size - offset);
    }
  }
}

```
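To see how the grid-stride loop splits the input between the two branches, the loop can be replayed on the host. This is a rough sketch under stated assumptions: `TileCount` and `CountTiles` are illustrative names that do not exist in Paddle, and the function only repeats the offset arithmetic of the kernel above for every block.

```cpp
#include <cassert>

struct TileCount {
  int full;      // tiles dispatched with IsBoundary = false
  int boundary;  // tiles dispatched with IsBoundary = true
};

// Replays the kernel's offset arithmetic for every block and counts which
// branch each tile would take. Purely illustrative; no GPU is involved.
inline TileCount CountTiles(int size, int vec_size, int block_dim, int grid_dim) {
  assert(size >= 0 && vec_size > 0 && block_dim > 0 && grid_dim > 0);
  const int tile = vec_size * block_dim;  // elements one block handles per iteration
  const int stride = grid_dim * tile;     // distance between a block's consecutive tiles
  TileCount count = {0, 0};
  for (int block = 0; block < grid_dim; ++block) {
    for (int offset = block * tile; offset < size; offset += stride) {
      if (offset + tile < size) {
        ++count.full;      // same condition as the IsBoundary = false branch
      } else {
        ++count.boundary;  // tile reaches or passes the end of the data
      }
    }
  }
  return count;
}
```

With `size = 1000`, `VecSize = 4`, `blockDim.x = 64` and `gridDim.x = 2`, three tiles take the fast path and the last one takes the boundary path. Because the comparison is strict, a final tile that fits exactly (for example `size = 1024`) is also routed through the boundary path, which is safe but slightly conservative.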
13 changes: 13 additions & 0 deletions docs/guides/07_new_op/kernel_primitive_api/api_description_cn.rst
@@ -0,0 +1,13 @@
API Description
###############

- `IO API <./io_api_cn.html>`_ : Describes the definition and functions of the IO APIs.
- `Compute API <./compute_api_cn.html>`_ : Describes the definition and functions of the Compute APIs.
- `OpFunc <./functor_api_cn.html>`_ : Introduces the functors provided by the Kernel Primitive API.

.. toctree::
:hidden:

io_api_cn.md
compute_api_cn.md
functor_api_cn.md
13 changes: 13 additions & 0 deletions docs/guides/07_new_op/kernel_primitive_api/api_description_en.rst
@@ -0,0 +1,13 @@
API Description
###############

- `IO API <./io_api_en.html>`_ : Describes the definition and functions of IO APIs.
- `Compute API <./compute_api_en.html>`_ : Describes the definition and functions of compute APIs.
- `OpFunc <./functor_api_en.html>`_ : Introduces the functors provided by the Kernel Primitive API.

.. toctree::
:hidden:

io_api_en.md
compute_api_en.md
functor_api_en.md
175 changes: 175 additions & 0 deletions docs/guides/07_new_op/kernel_primitive_api/compute_api_cn.md
@@ -0,0 +1,175 @@
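# API Description - Compute
## [ElementwiseUnary](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/kernel_primitives/compute_primitives.h#L138)
### Function Definition

```
template <typename InT, typename OutT, int NX, int NY, int BlockSize, class OpFunc>
__device__ void ElementwiseUnary(OutT* out, const InT* in, OpFunc compute)
```

### Description

Applies the calculation rule in OpFunc to in and stores the result, as type OutT, in the register out.

### Template Parameters

> InT : Type of the input data.</br>
> OutT : Type stored in the out register.</br>
> NX : Each thread computes NX columns of data.</br>
> NY : Each thread computes NY rows of data.</br>
> BlockSize : Device attribute that identifies the thread indexing scheme of the current device. On GPU, threadIdx.x is used as the thread index; this parameter is not supported yet.</br>
> OpFunc : Calculation rule. See the OpFunc section for how to define it.</br>

### Parameters

> out : Pointer to the output register, of size NX * NY.</br>
> in : Pointer to the input register, of size NX * NY.</br>
> compute : Calculation function, declared as OpFunc&lt;InT, OutT&gt;().</br>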
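The per-thread behavior described above can be sketched on the host. This is a minimal C++ emulation, not the Paddle source: `ElementwiseUnaryHost` and `SquareFunctor` are hypothetical names, the `BlockSize` parameter is omitted, and the loop over `NX * NY` register elements is an assumption based on the sizes listed for out and in.

```cpp
#include <cassert>

// Hypothetical functor: squares the input and returns it as OutT.
template <typename InT, typename OutT>
struct SquareFunctor {
  OutT operator()(const InT &x) const {
    return static_cast<OutT>(x) * static_cast<OutT>(x);
  }
};

// Host-side emulation of what one thread does: apply compute to each of the
// NX * NY elements it holds and store the result as OutT.
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
void ElementwiseUnaryHost(OutT *out, const InT *in, OpFunc compute) {
  assert(out != nullptr && in != nullptr);
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in[idx]));
  }
}
```

For example, `ElementwiseUnaryHost<int, float, 4, 1>(out, in, SquareFunctor<int, float>())` turns four `int` inputs into four `float` squares, which shows how InT and OutT may differ.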

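## [ElementwiseBinary](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/kernel_primitives/compute_primitives.h#L173)
### Function Definition

```
template <typename InT, typename OutT, int NX, int NY, int BlockSize, class OpFunc>
__device__ void ElementwiseBinary(OutT* out, const InT* in1, const InT* in2, OpFunc compute)
```

### Description

Applies the calculation rule in OpFunc to in1 and in2 and stores the result, as type OutT, in the register out.

### Template Parameters

> InT : Type of the input data.</br>
> OutT : Type stored in the out register.</br>
> NX : Each thread computes NX columns of data.</br>
> NY : Each thread computes NY rows of data.</br>
> BlockSize : Device attribute that identifies the thread indexing scheme of the current device. On GPU, threadIdx.x is used as the thread index; this parameter is not supported yet.</br>
> OpFunc : Calculation rule. See the OpFunc section for how to define it.</br>

### Parameters

> out : Pointer to the output register, of size NX * NY.</br>
> in1 : Pointer to the left operand register, of size NX * NY.</br>
> in2 : Pointer to the right operand register, of size NX * NY.</br>
> compute : Calculation object, declared as OpFunc&lt;InT&gt;().</br>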

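## [CycleBinary](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/kernel_primitives/compute_primitives.h#L291)

### Function Definition

```
template <typename InT, typename OutT, int NX, int NY, int BlockSize, class OpFunc>
__device__ void CycleBinary(OutT* out, const InT* in1, const InT* in2, OpFunc compute)
```

### Description

Applies the calculation rule in OpFunc to in1 and in2 and stores the result, as type OutT, in the register out. in1 has shape [1, NX] and in2 has shape [NY, NX]; in1 is reused cyclically for each row of in2, and out has shape [NY, NX].

### Template Parameters

> InT : Type of the input data.</br>
> OutT : Type stored in the out register.</br>
> NX : Each thread computes NX columns of data.</br>
> NY : Each thread computes NY rows of data.</br>
> BlockSize : Device attribute that identifies the thread indexing scheme of the current device. On GPU, threadIdx.x is used as the thread index; this parameter is not supported yet.</br>
> OpFunc : Calculation rule. See the OpFunc section for how to define it.</br>

### Parameters

> out : Pointer to the output register, of size NX * NY.</br>
> in1 : Pointer to the left operand register, of size NX.</br>
> in2 : Pointer to the right operand register, of size NX * NY.</br>
> compute : Calculation object, declared as OpFunc&lt;InT&gt;().</br>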
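The cyclic reuse of in1 across the rows of in2 can be sketched on the host as follows. This is an illustrative emulation, not the Paddle source: `CycleBinaryHost` and `CycleAddFunctor` are hypothetical names, `BlockSize` is omitted, and a row-major layout (`idy * NX + idx`) is assumed for in2 and out.

```cpp
#include <cassert>

// Hypothetical functor used only for this sketch.
template <typename T>
struct CycleAddFunctor {
  T operator()(const T &a, const T &b) const { return a + b; }
};

// Host-side emulation: in1 has NX elements and is reused for each of the NY
// rows of in2. Row-major indexing is an assumption of this sketch.
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
void CycleBinaryHost(OutT *out, const InT *in1, const InT *in2, OpFunc compute) {
  assert(out != nullptr && in1 != nullptr && in2 != nullptr);
  for (int idy = 0; idy < NY; ++idy) {
    for (int idx = 0; idx < NX; ++idx) {
      out[idy * NX + idx] =
          static_cast<OutT>(compute(in1[idx], in2[idy * NX + idx]));
    }
  }
}
```

With `NX = 3` and `NY = 2`, the three values of in1 are added to the first row of in2 and then again to the second row, so in1 behaves like a broadcast row vector.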

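## [ElementwiseTernary](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/kernel_primitives/compute_primitives.h#L210)

### Function Definition

```
template <typename InT, typename OutT, int NX, int NY, int BlockSize, class OpFunc>
__device__ void ElementwiseTernary(OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute)

```

### Description

Applies the calculation rule in OpFunc to in1, in2 and in3 and stores the result, as type OutT, in the register out.

### Template Parameters

> InT : Type of the input data.</br>
> OutT : Type stored in the out register.</br>
> NX : Each thread computes NX columns of data.</br>
> NY : Each thread computes NY rows of data.</br>
> BlockSize : Device attribute that identifies the thread indexing scheme of the current device. On GPU, threadIdx.x is used as the thread index; this parameter is not supported yet.</br>
> OpFunc : Calculation rule. See the OpFunc section for how to define it.</br>

### Parameters

> out : Pointer to the output register, of size NX * NY.</br>
> in1 : Pointer to the register of operand 1, of size NX * NY.</br>
> in2 : Pointer to the register of operand 2, of size NX * NY.</br>
> in3 : Pointer to the register of operand 3, of size NX * NY.</br>
> compute : Calculation object, declared as OpFunc&lt;InT&gt;().</br>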
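A three-operand rule such as multiply-add is one plausible use of this primitive. The sketch below emulates the per-thread loop on the host; `ElementwiseTernaryHost` and `MulAddFunctor` are hypothetical names chosen for illustration, and `BlockSize` is omitted.

```cpp
#include <cassert>

// Hypothetical functor: a * b + c.
template <typename T>
struct MulAddFunctor {
  T operator()(const T &a, const T &b, const T &c) const { return a * b + c; }
};

// Host-side emulation: combine the three operands position by position.
template <typename InT, typename OutT, int NX, int NY, class OpFunc>
void ElementwiseTernaryHost(OutT *out, const InT *in1, const InT *in2,
                            const InT *in3, OpFunc compute) {
  assert(out != nullptr && in1 != nullptr && in2 != nullptr && in3 != nullptr);
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
```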

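## [ElementwiseAny](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/kernel_primitives/compute_primitives.h#L250)

### Function Definition

```
template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity, class OpFunc>
__device__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY], OpFunc compute)
```

### Description

Applies the calculation rule in OpFunc to the inputs in ins and stores the result, as type OutT, in the register out. All inputs and the output have the same dimensions.

### Template Parameters

> InT : Type of the input data.</br>
> OutT : Type stored in the out register.</br>
> NX : Each thread computes NX columns of data.</br>
> NY : Each thread computes NY rows of data.</br>
> BlockSize : Device attribute that identifies the thread indexing scheme of the current device. On GPU, threadIdx.x is used as the thread index; this parameter is not supported yet.</br>
> Arity : Number of pointers in the pointer array ins.</br>
> OpFunc : Calculation rule. See the OpFunc section for how to define it.</br>

### Parameters

> out : Pointer to the output register, of size NX * NY.</br>
> ins : Pointer array made up of the input pointers, of size Arity.</br>
> compute : Calculation object, declared as OpFunc&lt;InT&gt;().</br>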
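One way to picture the Arity-input case is to gather the values at one position from every input and hand them to the functor together. The sketch below does this on the host. It is an assumption about the calling convention, not a statement of the Paddle implementation: `ElementwiseAnyHost` and `SumArgsFunctor` are hypothetical, `BlockSize` is omitted, and the functor is assumed to receive a pointer to `Arity` gathered values.

```cpp
#include <cassert>

// Hypothetical functor: sums the Arity values gathered for one position.
template <typename InT, int Arity>
struct SumArgsFunctor {
  InT operator()(const InT *args) const {
    InT total = args[0];
    for (int j = 1; j < Arity; ++j) {
      total += args[j];
    }
    return total;
  }
};

// Host-side emulation: for each of the NX * NY positions, collect one value
// from every input and pass the collected array to compute.
template <typename InT, typename OutT, int NX, int NY, int Arity, class OpFunc>
void ElementwiseAnyHost(OutT *out, InT (*ins)[NX * NY], OpFunc compute) {
  assert(out != nullptr && ins != nullptr);
  InT args[Arity];
  for (int idx = 0; idx < NX * NY; ++idx) {
    for (int j = 0; j < Arity; ++j) {
      args[j] = ins[j][idx];
    }
    out[idx] = static_cast<OutT>(compute(args));
  }
}
```

Here `ins` is declared as `int ins[3][2]` for three inputs of two elements each, which decays to the `InT (*)[NX * NY]` parameter type when `NX = 2` and `NY = 1`.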

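## [Reduce](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/kernel_primitives/compute_primitives.h#L332)

### Function Definition

```
template <typename T, int NX, int NY, int BlockSize, class ReduceFunctor, details::ReduceMode Mode>
__device__ void Reduce(T* out, const T* in, ReduceFunctor reducer, bool reduce_last_dim)
```

### Description

Reduces the data in in according to reducer. The input in has shape [NY, NX]. When Mode = kLocalMode, in is reduced along the NX direction inside the thread, and out has shape [NY, 1]. When Mode = kGlobalMode, shared memory is used to reduce across the threads of a block, and in and out have the same size, [NY, NX].</br>
The data processing flow of ReduceMax is as follows:</br>
![Reduce](./images/compute_reduce.png)

### Template Parameters

> T : Type of the input data.</br>
> NX : Each thread computes NX columns of data.</br>
> NY : Each thread computes NY rows of data.</br>
> BlockSize : Device attribute that identifies the thread indexing scheme of the current device. On GPU, threadIdx.x is used as the thread index; this parameter is not supported yet.</br>
> ReduceFunctor : Calculation rule. See the OpFunc section for how to define it.</br>
> Mode : Reduction mode, either kGlobalMode or kLocalMode.

### Parameters

> out : Pointer to the output register, of size NX * NY.</br>
> in : Pointer to the input register, of size NX * NY.</br>
> reducer : Reduction rule, which can be defined with ReduceFunctor&lt;T&gt;().</br>
> reduce_last_dim : Indicates whether the last dimension of the original input is reduced.</br>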
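The kLocalMode case (reduction inside one thread along NX) can be sketched on the host. This is an illustrative emulation only: `ReduceLocalHost` and `MaxFunctorSketch` are hypothetical names, the sketch assumes that out already holds the initial value of the reduction for each row, and it does not cover kGlobalMode, which relies on shared memory and block-level synchronization.

```cpp
#include <cassert>

// Hypothetical functor for a max reduction.
template <typename T>
struct MaxFunctorSketch {
  T operator()(const T &a, const T &b) const { return a > b ? a : b; }
};

// Host-side emulation of a thread-local reduction: each of the NY rows of the
// [NY, NX] input is folded into out[row]. out must be initialized by the
// caller (an assumption of this sketch).
template <typename T, int NX, int NY, class ReduceFunctor>
void ReduceLocalHost(T *out, const T *in, ReduceFunctor reducer) {
  assert(out != nullptr && in != nullptr);
  for (int row = 0; row < NY; ++row) {
    for (int col = 0; col < NX; ++col) {
      out[row] = reducer(out[row], in[row * NX + col]);
    }
  }
}
```

For a max reduction, initializing each `out[row]` with the first element of its row (or the lowest representable value) gives the row-wise maximum after the call.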