Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions paddle/phi/kernels/funcs/broadcast_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,29 @@ struct BroadcastDataLoader {
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
#ifdef PADDLE_WITH_XPU_KP
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

89行的vec_size只与out有关吗? 看你修改前的代码是与in/out同时有关的,不确定这里会不会隐藏性能问题

Copy link
Contributor Author

@zhangboSJTU zhangboSJTU May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考elementwise也是只取了out的vec_size

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elementwise是因为dim是相同的,而broadcast 输入输出的dim可能是不同的……

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vec_size 原本是取 min( in out 4), 现在是取min( out 4),那应该值是>=之前的值,所以应该不会造成性能下降,有其他原因考虑需要加上吗

kps::Init<Type, ArgsT, Index, VecSize>(
args, static_cast<Type>(1.0f), read_lens);
if (use_broadcast[Index]) {
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]),
block_offset,
configs[Index],
numel,
read_lens);
} else {
kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
num,
read_lens);
}
#else
kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));

if (use_broadcast[Index]) {
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
args,
Expand All @@ -133,6 +152,7 @@ struct BroadcastDataLoader {
num,
VecSize);
}
#endif
}
};

Expand All @@ -148,7 +168,8 @@ struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
Expand All @@ -173,7 +194,8 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
const Array3 &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
const uint32_t numel,
int read_lens) {
using Type = std::tuple_element_t<Index, ArgsT>;
using VecType = phi::kps::details::VectorType<Type, VecSize>;
VecType vec_temp;
Expand Down Expand Up @@ -269,6 +291,10 @@ __device__ void VectorizedBroadcastKernelImpl(
__simd__ ArgsT args[VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];

#ifdef PADDLE_WITH_XPU_KP
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么要单独区分kp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XPUKP 在broadcast的功能与GPU是一样的呀

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里之前铭书对GPU的broadcast进行了特化优化(减少了其中重复的fast_divmod计算),这里为了保持其优化效果,就需要单独拿出来

BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
#else
if (LoadType == kBroadcast) {
uint32_t index_bc[Arity][VecSize] = {0};
Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
Expand All @@ -291,9 +317,9 @@ __device__ void VectorizedBroadcastKernelImpl(
Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
} else {
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
ins, args, configs, use_broadcast, block_offset, num, numel);
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

read_lens是给XPU KP 使用的,此处代码已经被else包含为什么还要添加read_lens

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上一个 comment 中,gpu部分做了特化,但 kp和 gpu 使用的是相同的非特化函数,参数就需要保持一致了

}

#endif
SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
VecSize,
Functor,
Expand Down
59 changes: 59 additions & 0 deletions paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,65 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}

/**
 * @brief Read 1D data from global memory into registers in broadcast form,
 * storing each element into field `Index` of a tuple-of-arguments register
 * array. Unlike the plain-pointer overload above, the destination is an
 * ArgsT (tuple) array, which lets inputs of different data types share one
 * loader.
 *
 * @template paraments
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread (capacity of the
 *     per-thread staging buffer).
 * NY: The number of data rows loaded by each thread; only NY = 1 is
 *     supported. core_id() is used as the index.
 * ArgsT: The tuple type holding one element per kernel input.
 * Index: Which tuple field of ArgsT receives the loaded values.
 * IsBoundary: Whether out-of-bounds checks are required, i.e. when the block
 *     processes fewer than NX x NY x core_num() elements.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The original input data pointer of the kernel.
 * block_offset: The data offset of this block, core_num() * blockIdx.x * NX.
 * config: Broadcast configuration used to map output coordinates back to
 *     input coordinates; its cmp_type selects a specialized fast path.
 * total_num_output: Total number of original output elements.
 * read_lens: The number of data continuously loaded by each thread.
 *     NOTE(review): the staging buffer holds NX elements, so this presumably
 *     satisfies read_lens <= NX — confirm against callers.
 */
template <typename T,
          int NX,
          int NY,
          typename ArgsT,
          int Index,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    ArgsT* dst,
    const T _global_ptr_* src,
    int block_offset,
    const details::BroadcastConfig& config,
    int total_num_output,
    int read_lens = NX) {
  // Per-thread staging buffer in local memory; filled by the selected
  // broadcast reader, then scattered into the tuple field below.
  __local__ T staging[NX];
  const int thread_offset = block_offset + core_id() * read_lens;

  // Dispatch on the precomputed broadcast pattern; the default case is the
  // generic (uncompressed) path with optional boundary checking.
  switch (config.cmp_type) {
    case details::OptType::MNK_M1K:
      ReadDataBcM1kMnk<T>(staging, src, thread_offset, config, read_lens);
      break;
    case details::OptType::N_1:
      ReadDataBc1N<T>(staging, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MN_M:
      ReadDataBcM1Mn<T>(staging, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MN_N:
      ReadDataBc1NMn<T>(staging, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MNK_1N1:
      ReadDataBc1N1Mnk<T>(staging, src, thread_offset, config, read_lens);
      break;
    default:
      ReadDataBcCanNotCmp<T, IsBoundary>(
          staging, src, thread_offset, config, total_num_output, read_lens);
      break;
  }
#pragma unroll
  for (int i = 0; i < read_lens; ++i) {
    std::get<Index>(dst[i]) = staging[i];
  }
}

/**
* @brief Initialize register with data index.
*
Expand Down