Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move fused_concat op to paddle #115

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ if(WITH_XPU)
SRCS fused_seq_tensor_kernel.h fused_seq_tensor_kernel.kps
DEPS resnet_unit_op)

xpu_library(fused_concat_kernel
SRCS fused_concat_kernel.h fused_concat_kernel.kps
DEPS resnet_unit_op)

register_operators(
DEPS
fused_seqpool_cvm_kernel
fused_seq_tensor_kernel
fused_concat_kernel
EXCLUDES
fused_bn_activation_op
conv_fusion_op
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/operators/fused/fused_concat_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"


namespace paddle {
namespace framework {

template<typename T>
int fused_concat(xpu::Context* ctx,
const std::vector<const T*>& x_list,
T* y,
int batch_size,
int dim_size,
int length,
int offset);

template<typename T>
int fused_concat_grad(xpu::Context* ctx,
const T* dy,
std::vector<T*>& dx_vec,
int batch_size,
int dim_size,
int length,
int offset);

} // namespace framework
} // namespace paddle
212 changes: 212 additions & 0 deletions paddle/fluid/operators/fused/fused_concat_kernel.kps
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#include "xpu/kernel/xtdk.h" // NOLINT
#include "xpu/kernel/xtdk_math.h" // NOLINT
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/refactor/impl_public/wrapper_check.h"

#include "paddle/fluid/platform/device/xpu/xpu_header.h"

#include "paddle/fluid/operators/fused/fused_concat_kernel.h"

namespace paddle {
namespace framework {

template int fused_concat(xpu::Context*,
const std::vector<const float*>&,
float*,
int,
int,
int,
int);

template int fused_concat_grad<float>(xpu::Context* ctx,
const float*,
std::vector<float*>&,
int,
int,
int,
int);

// forward

template <typename T>
__global__ void fused_concat_fwd(unsigned long long* x_addr,
T* y_addr,
int batch_size,
int dim_size,
int x_num,
int length,
int offset) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;

const int buf_size = 512;
__local__ T local_x[buf_size];

int max_num = 256;
__local__ uint64_t x_list[max_num];
const int total_cols = x_num * length;

for (int64_t batch_id = thread_id; batch_id < batch_size; batch_id += nthreads) {
for (int len_offset = 0; len_offset < x_num; len_offset += max_num) {
int len = min(x_num - len_offset, max_num);
// mfence();
GM2LM(x_addr + len_offset, x_list, len * sizeof(uint64_t));

for (int index = 0; index < len; index++) {
__global_ptr__ T* cur_x = reinterpret_cast<__global_ptr__ T*>(x_list[index]);
mfence();

GM2LM_ASYNC(cur_x + batch_id * dim_size, local_x, dim_size * sizeof(T));
mfence();
LM2GM_ASYNC(local_x + offset,
y_addr + batch_id * total_cols + (len_offset + index) * length,
length * sizeof(T));
}
}
}
}

template<typename T>
int fused_concat(xpu::Context* ctx,
const std::vector<const T*>& x_list,
T* y,
int batch_size,
int dim_size,
int length,
int offset) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fused_concat", T);
WRAPPER_DUMP_PARAM2(ctx, x_list, y);
WRAPPER_DUMP_PARAM4(ctx, batch_size, dim_size, length, offset);
WRAPPER_DUMP(ctx);

WRAPPER_ASSERT_GE(ctx, x_list.size(), 1);
WRAPPER_ASSERT_LE(ctx, x_list.size(), INT32_MAX);
WRAPPER_ASSERT_GE(ctx, length, 1);
WRAPPER_ASSERT_GE(ctx, dim_size, length);
WRAPPER_ASSERT_GE(ctx, offset, 0);

int x_num = x_list.size();
for (int i = 0; i < x_num; i++) {
WRAPPER_CHECK_PTR(ctx, T, batch_size * dim_size, x_list[i]);
}
WRAPPER_CHECK_PTR(ctx, T, batch_size * x_num * length, y);

api::ctx_guard RAII_GUARD(ctx);
const T** x_xpu = RAII_GUARD.alloc_l3_or_gm<const T*>(x_num);
WRAPPER_ASSERT_WORKSPACE(ctx, x_xpu);
WRAPPER_ASSERT_SUCCESS(ctx, api::do_host2device(ctx, x_list.data(), x_xpu, x_num * sizeof(T*)));

if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
auto xpu_x_addr_ptr = reinterpret_cast<unsigned long long*>(x_xpu);
int ret = fused_concat_fwd<T><<<ctx->ncluster(), 64, ctx->xpu_stream>>>(xpu_x_addr_ptr,
y,
batch_size,
dim_size,
x_num,
length,
offset);
KERNEL_ASSERT_SUCCESS(ctx, ret);
return api::SUCCESS;
}
WRAPPER_UNIMPLEMENTED(ctx);
}

// backward

template <typename T>
__global__ void fused_concat_bwd(const T* dy_addr,
unsigned long long* dx_addr,
int batch_size,
int dim_size,
int x_num,
int length,
int offset) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = cluster_id() * ncores + cid;
int nthreads = cluster_num() * ncores;

int max_num = 256;
__local__ T local_y_grad[length];
unsigned long long dx_list[max_num];
const int total_cols = x_num * length;

for (int64_t batch_id = thread_id; batch_id < batch_size; batch_id += nthreads) {
for (int len_offset = 0; len_offset < x_num; len_offset += max_num) {
int len = min(x_num - len_offset, max_num);
// mfence();
GM2LM(dx_addr + len_offset, dx_list, len * sizeof(uint64_t));

for (int index = 0; index < len; index++) {
__global_ptr__ T* cur_dx = reinterpret_cast<__global_ptr__ T*>(dx_list[index]);
mfence();

GM2LM_ASYNC(dy_addr + batch_id * total_cols + (len_offset + index) * length,
local_y_grad,
length * sizeof(T));
mfence_lm();
LM2GM_ASYNC(local_y_grad,
cur_dx + batch_id * dim_size + offset,
length * sizeof(T));
}
}
}
}

template<typename T>
int fused_concat_grad(xpu::Context* ctx,
const T* dy,
std::vector<T*>& dx_vec,
int batch_size,
int dim_size,
int length,
int offset) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fused_concat_grad", T);
WRAPPER_DUMP_PARAM2(ctx, dx_vec, dy);
WRAPPER_DUMP_PARAM4(ctx, batch_size, dim_size, length, offset);
WRAPPER_DUMP(ctx);

WRAPPER_ASSERT_GE(ctx, dx_vec.size(), 1);
WRAPPER_ASSERT_LE(ctx, dx_vec.size(), INT32_MAX);
WRAPPER_ASSERT_GE(ctx, length, 1);
WRAPPER_ASSERT_GE(ctx, dim_size, length);
WRAPPER_ASSERT_GE(ctx, offset, 0);

int x_num = dx_vec.size();
WRAPPER_CHECK_PTR(ctx, T, batch_size * x_num * length, dy);
for (int i = 0; i < x_num; i++) {
WRAPPER_CHECK_PTR(ctx, T, batch_size * dim_size, dx_vec[i]);
}

api::ctx_guard RAII_GUARD(ctx);
T** dx_xpu = RAII_GUARD.alloc_l3_or_gm<T*>(x_num);
WRAPPER_ASSERT_WORKSPACE(ctx, dx_xpu);
WRAPPER_ASSERT_SUCCESS(ctx, api::do_host2device(ctx, dx_vec.data(), dx_xpu, x_num * sizeof(T*)));

if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
auto xpu_x_addr_ptr = reinterpret_cast<unsigned long long*>(dx_xpu);
int ret = fused_concat_bwd<T><<<ctx->ncluster(), 64, ctx->xpu_stream>>>(dy,
xpu_x_addr_ptr,
batch_size,
dim_size,
x_num,
length,
offset);
KERNEL_ASSERT_SUCCESS(ctx, ret);
return api::SUCCESS;
}
WRAPPER_UNIMPLEMENTED(ctx);
}

} // namespace framework
} // namespace paddle
7 changes: 5 additions & 2 deletions paddle/fluid/operators/fused/fused_concat_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fused_concat_op.h"
#include "paddle/fluid/operators/fused/fused_concat_kernel.h"
#ifdef PADDLE_WITH_BOX_PS
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#else
Expand Down Expand Up @@ -82,7 +83,8 @@ class FusedConcatOpXPUKernel : public framework::OpKernel<T> {
#ifdef TRACE_PROFILE
TRACE_SCOPE_START("xpu::fused_concat", xpu_wait(xpu_context->xpu_stream););
#endif
int r = xpu::fused_concat<T>(xpu_context,
// int r = xpu::fused_concat<T>(xpu_context,
int r = paddle::framework::fused_concat<T>(xpu_context,
cpu_x_addr_vec,
cpu_y_addr,
batch_size,
Expand Down Expand Up @@ -132,7 +134,8 @@ class FusedConcatGradOpXPUKernel : public framework::OpKernel<T> {

// output
auto cpu_dy_addr = out_grad->data<T>();
int r = xpu::fused_concat_grad<T>(xpu_context,
// int r = xpu::fused_concat_grad<T>(xpu_context,
int r = paddle::framework::fused_concat_grad<T>(xpu_context,
cpu_dy_addr,
cpu_dx_list,
batch_size,
Expand Down