277 changes: 195 additions & 82 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -254,59 +254,123 @@ int GetVecsize(const std::vector<const DenseTensor *> &ins,
return std::min(out_vec_size, in_vec_size);
}

template <typename T, int VecSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const _ptr_ T *src,
uint32_t block_offset,
const kps::details::BroadcastConfig &config,
int numel,
int num,
int need_broadcast,
int read_lens) {
// numel : whole num of output
// num: how many data will be deal with in this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
dst, src, block_offset, config, numel, read_lens);
} else {
kps::ReadData<T, VecSize, 1, IsBoundary>(
dst, src + block_offset, num, read_lens);
#ifndef PADDLE_WITH_XPU_KP
template <typename T,
int VecSize,
int Arity,
bool IsBoundary,
bool is_all_broadcast>
struct BroadcastDataLoader {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
#pragma unroll
for (int i = 0; i < Arity; ++i) {
kps::Init<T, VecSize>(args[i], static_cast<T>(1.0f));
if (use_broadcast[i]) {
kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
args[i], ins[i], block_offset, configs[i], numel, VecSize);
} else {
kps::ReadData<T, VecSize, 1, IsBoundary>(
args[i], ins[i] + block_offset, num, VecSize);
}
}
}
}
};

template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, true> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
uint32_t index_bc[Arity][VecSize];
#pragma unroll
for (int j = 0; j < Arity; ++j) {
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
index_bc[j][k] = 0;
args[j][k] = static_cast<T>(1);
}
}

uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
uint32_t idx = thread_offset + k;
if (IsBoundary) {
if (idx == numel) break;
}

#pragma unroll
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
if (i == configs[0].kDims) break;
auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
idx = fast_divmoder.val[0];
#pragma unroll
for (int j = 0; j < Arity; ++j) {
index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
}
}
}

#pragma unroll
for (int j = 0; j < Arity; ++j) {
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
args[j][k] = ins[j][index_bc[j][k]];
}
}
}
};
#endif

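The all-broadcast specialization above decodes each flat output index once with `configs[0]`'s fast divmod and accumulates a per-input offset from every input's strides, instead of re-running `ReadDataBc` once per input. Below is a minimal host-side C++ sketch of that index arithmetic; the shapes, stride values, and variable names are illustrative assumptions, not values taken from `BroadcastConfig`.

```cpp
// Host-side sketch of the index arithmetic in the all-broadcast loader.
// Dims/strides here are made up for illustration, not Paddle's BroadcastConfig.
#include <cstdint>
#include <cstdio>

int main() {
  // Output shape [3, 4]; input A has shape [3, 4], input B has shape [1, 4]
  // and broadcasts over the outer dimension.
  const int rank = 2;
  const uint32_t out_dims[rank]  = {4, 3};   // innermost dimension first
  const uint32_t strides_a[rank] = {1, 4};   // full-size input: ordinary strides
  const uint32_t strides_b[rank] = {1, 0};   // broadcast input: stride 0 on the broadcast dim

  for (uint32_t idx = 0; idx < 12; ++idx) {  // one flat output index per iteration
    uint32_t rem = idx, off_a = 0, off_b = 0;
    for (int d = 0; d < rank; ++d) {         // the kernel unrolls this loop with fast divmod
      uint32_t coord = rem % out_dims[d];
      rem /= out_dims[d];
      off_a += coord * strides_a[d];
      off_b += coord * strides_b[d];
    }
    std::printf("out %2u -> a %2u, b %u\n", idx, off_a, off_b);
  }
  return 0;
}
```

On the device the inner loop is fully unrolled and the modulo/divide pair is replaced by the precomputed `divmoders`, which is where the saving over the generic per-input loader comes from.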
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize,
bool IsBoundary = false>
bool IsBoundary,
bool IsAllBroadcast = false>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
uint32_t numel,
const uint32_t numel,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
int read_lens,
Functor func) {
__simd__ InT args[Arity][VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];

#ifdef PADDLE_WITH_XPU_KP
#pragma unroll
for (int i = 0; i < Arity; ++i) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
LoadData<InT, VecSize, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i],
read_lens);
if (use_broadcast[i]) {
kps::ReadDataBc<InT, VecSize, 1, IsBoundary>(
args[i], ins[i], block_offset, configs[i], numel, read_lens);
} else {
kps::ReadData<InT, VecSize, 1, IsBoundary>(
args[i], ins[i] + block_offset, num, read_lens);
}
}
#else
BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, IsAllBroadcast>()(
args, ins, configs, use_broadcast, block_offset, num, numel);
#endif

constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
phi::funcs::ElementwisePrimitiveCaller<InT,
@@ -321,12 +385,13 @@ __device__ void VectorizedBroadcastKernelImpl(
outs, result, block_offset, num, read_lens);
}

template <typename InT,
template <typename Functor,
typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize>
int VecSize,
bool IsAllBroadcast>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
@@ -337,26 +402,26 @@ __global__ void VectorizedBroadcastKernel(
int tail_tid,
int read_lens,
Functor func) {
#ifdef PADDLE_WITH_XPU_KP
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;

#ifdef PADDLE_WITH_XPU_KP
for (; block_offset < main_offset; block_offset += stride) {
VectorizedBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize,
false>(ins,
outs,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * read_lens,
block_offset,
read_lens,
func);
false,
IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * read_lens,
block_offset,
read_lens,
func);
}
int num = numel - block_offset;
if (num > 0) {
@@ -366,49 +431,53 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
true>(ins,
outs,
use_broadcast,
numel,
configs,
num,
block_offset,
read_lens,
func);
true,
IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
configs,
num,
block_offset,
read_lens,
func);
}
#else
Contributor:

Actually, the original idea behind KP was to avoid adding this kind of branching as much as possible; once it is added, it is no different from just writing two separate kernels....

Contributor (Author):

For the AlphaFold optimization I really cannot think of anything else left to optimize... Tuning a kernel like this, where the computation is simple but the performance requirements are high, is like growing flowers in the desert T_T

int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
if (block_offset < main_offset) {
VectorizedBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize,
false>(ins,
outs,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
block_offset,
read_lens,
func);
false,
IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
configs,
BLOCK_NUM_X * VecSize,
block_offset,
read_lens,
func);
} else {
VectorizedBroadcastKernelImpl<InT,
OutT,
Functor,
Arity,
NumOuts,
VecSize,
true>(ins,
outs,
use_broadcast,
numel,
configs,
tail_tid,
block_offset,
read_lens,
func);
true,
IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
configs,
tail_tid,
block_offset,
read_lens,
func);
}
#endif
}
Expand All @@ -425,6 +494,7 @@ void LaunchBroadcastKernel(
std::vector<DenseTensor *> *outs,
Functor func,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs) {
int broadcast_num = 0;
int numel = (*outs)[0]->numel();
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
@@ -435,7 +505,12 @@ }
}

for (int i = 0; i < Arity; ++i) {
use_broadcast[i] = (ins[i]->numel() != numel);
if (ins[i]->numel() != numel) {
broadcast_num++;
use_broadcast[i] = true;
} else {
use_broadcast[i] = false;
}
ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
}

@@ -446,6 +521,17 @@
auto stream = ctx.x_context()->xpu_stream;
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
int tail_tid = numel % (read_lens * threads);

VectorizedBroadcastKernel<Functor, InT, OutT, Arity, NumOuts, VecSize, false>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
#else
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
@@ -456,17 +542,43 @@
int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) *
read_lens * gpu_config.GetBlockSize();
int tail_tid = numel % (read_lens * gpu_config.GetBlockSize());

if (broadcast_num > (Arity >> 1)) {
VectorizedBroadcastKernel<Functor,
InT,
OutT,
Arity,
NumOuts,
VecSize,
(Arity > 1)>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
} else {
VectorizedBroadcastKernel<Functor,
InT,
OutT,
Arity,
NumOuts,
VecSize,
false>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
}
#endif
VectorizedBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
}

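As a quick restatement of the dispatch in the CUDA branch above: the index-sharing loader is taken only when more than half of the inputs actually broadcast, and the `(Arity > 1)` template argument keeps the single-input case on the generic path. A tiny standalone sketch of that combined condition follows; the helper name is made up for illustration and is not part of the Paddle API.

```cpp
// Sketch of the dispatch heuristic: prefer the index-sharing (all-broadcast)
// loader only when most inputs broadcast and there is more than one input.
#include <cstdio>

constexpr bool UseAllBroadcastLoader(int broadcast_num, int arity) {
  return broadcast_num > (arity >> 1) && arity > 1;
}

int main() {
  std::printf("%d\n", UseAllBroadcastLoader(2, 2));  // binary op, both inputs broadcast -> 1
  std::printf("%d\n", UseAllBroadcastLoader(1, 3));  // ternary op, one input broadcasts  -> 0
  return 0;
}
```

The intent is that the shared divmod pass only pays off when it replaces broadcast index computation for the majority of the inputs; otherwise the existing per-input loader stays in place.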
template <ElementwiseType ET,
@@ -536,6 +648,7 @@ void BroadcastKernelForDifferentVecSize(
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
// if (ins[i]->numel() != (*outs)[0]->numel()) {
if (ins[i]->numel()) {
configs[i] = kps::details::BroadcastConfig(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
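The comment above notes that `BroadcastConfig` expects dims innermost-first, so an `[m, n]` shape is passed as `{n, m}`. A tiny sketch of that reversal, with made-up variable names:

```cpp
// Illustration of the dim-ordering convention only; names are not Paddle API.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> out_shape = {3, 45, 1};                          // logical shape, outermost first
  std::vector<int> out_dims(out_shape.rbegin(), out_shape.rend());  // {1, 45, 3}: innermost first
  for (int d : out_dims) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}
```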