Skip to content

Commit 8d340ee

Browse files
authored
Fix xpu2 kp compile error (#53548)
1 parent 727fa27 commit 8d340ee

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,29 @@ struct BroadcastDataLoader {
110110
const Array3 &use_broadcast,
111111
const int block_offset,
112112
const int num,
113-
const uint32_t numel) {
113+
const uint32_t numel,
114+
int read_lens) {
114115
using Type = std::tuple_element_t<Index, ArgsT>;
116+
#ifdef PADDLE_WITH_XPU_KP
117+
kps::Init<Type, ArgsT, Index, VecSize>(
118+
args, static_cast<Type>(1.0f), read_lens);
119+
if (use_broadcast[Index]) {
120+
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
121+
args,
122+
reinterpret_cast<const _ptr_ Type *>(ins[Index]),
123+
block_offset,
124+
configs[Index],
125+
numel,
126+
read_lens);
127+
} else {
128+
kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
129+
args,
130+
reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
131+
num,
132+
read_lens);
133+
}
134+
#else
115135
kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
116-
117136
if (use_broadcast[Index]) {
118137
kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
119138
args,
@@ -133,6 +152,7 @@ struct BroadcastDataLoader {
133152
num,
134153
VecSize);
135154
}
155+
#endif
136156
}
137157
};
138158

@@ -148,7 +168,8 @@ struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
148168
const Array3 &use_broadcast,
149169
const int block_offset,
150170
const int num,
151-
const uint32_t numel) {
171+
const uint32_t numel,
172+
int read_lens) {
152173
using Type = std::tuple_element_t<Index, ArgsT>;
153174
int thread_offset = threadIdx.x * VecSize + block_offset;
154175
#pragma unroll
@@ -173,7 +194,8 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
173194
const Array3 &use_broadcast,
174195
const int block_offset,
175196
const int num,
176-
const uint32_t numel) {
197+
const uint32_t numel,
198+
int read_lens) {
177199
using Type = std::tuple_element_t<Index, ArgsT>;
178200
using VecType = phi::kps::details::VectorType<Type, VecSize>;
179201
VecType vec_temp;
@@ -269,6 +291,10 @@ __device__ void VectorizedBroadcastKernelImpl(
269291
__simd__ ArgsT args[VecSize];
270292
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];
271293

294+
#ifdef PADDLE_WITH_XPU_KP
295+
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
296+
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
297+
#else
272298
if (LoadType == kBroadcast) {
273299
uint32_t index_bc[Arity][VecSize] = {0};
274300
Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
@@ -291,9 +317,9 @@ __device__ void VectorizedBroadcastKernelImpl(
291317
Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
292318
} else {
293319
BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
294-
ins, args, configs, use_broadcast, block_offset, num, numel);
320+
ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
295321
}
296-
322+
#endif
297323
SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
298324
VecSize,
299325
Functor,

paddle/phi/kernels/primitive/datamover_primitives_xpu2.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,65 @@ __device__ __inline__ void ReadDataBc(T* dst,
12111211
}
12121212
}
12131213

1214+
/**
1215+
* @brief Read 1D data from global memory to register with broadcast form.
1216+
* The difference from the above function is that it supports different data
1217+
* types of inputs.
1218+
* @template parameters
1219+
* T: The type of data stored in the global memory.
1220+
* NX: The number of data continuously loaded by each thread.
1221+
* NY: The number of data rows loaded by each thread, only NY = 1 is supported.
1222+
* core_id() is used as the index.
1223+
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
1224+
* judgment. When the number of data processed by the block is less than
1225+
* NX x NY x core_num(), boundary judgment is required to avoid memory access
1226+
* crossing the boundary.
1227+
*
1228+
* @param:
1229+
* dst: The register pointer of the thread, the size is NX * NY.
1230+
* src: The original input data pointer of kernel.
1231+
* block_offset: The data offset of this block, core_num() * blockIdx.x * NX.
1232+
* config: Calculation configuration of broadcast. It is used to calculate the
1233+
* coordinate mapping relationship between output data and input data.
1234+
* read_lens: The number of data continuously loaded by each thread.
1235+
* total_num_output: Total number of original output.
1236+
*/
1237+
template <typename T,
1238+
int NX,
1239+
int NY,
1240+
typename ArgsT,
1241+
int Index,
1242+
bool IsBoundary = false>
1243+
__device__ __forceinline__ void ReadDataBc(
1244+
ArgsT* dst,
1245+
const T _global_ptr_* src,
1246+
int block_offset,
1247+
const details::BroadcastConfig& config,
1248+
int total_num_output,
1249+
int read_lens = NX) {
1250+
int thread_offset = block_offset + core_id() * read_lens;
1251+
__local__ T in_temp[NX];
1252+
1253+
if (config.cmp_type == details::OptType::MNK_M1K) {
1254+
ReadDataBcM1kMnk<T>(in_temp, src, thread_offset, config, read_lens);
1255+
} else if (config.cmp_type == details::OptType::N_1) {
1256+
ReadDataBc1N<T>(in_temp, src, thread_offset, config, read_lens);
1257+
} else if (config.cmp_type == details::OptType::MN_M) {
1258+
ReadDataBcM1Mn<T>(in_temp, src, thread_offset, config, read_lens);
1259+
} else if (config.cmp_type == details::OptType::MN_N) {
1260+
ReadDataBc1NMn<T>(in_temp, src, thread_offset, config, read_lens);
1261+
} else if (config.cmp_type == details::OptType::MNK_1N1) {
1262+
ReadDataBc1N1Mnk<T>(in_temp, src, thread_offset, config, read_lens);
1263+
} else {
1264+
ReadDataBcCanNotCmp<T, IsBoundary>(
1265+
in_temp, src, thread_offset, config, total_num_output, read_lens);
1266+
}
1267+
#pragma unroll
1268+
for (int idx = 0; idx < read_lens; ++idx) {
1269+
std::get<Index>(dst[idx]) = in_temp[idx];
1270+
}
1271+
}
1272+
12141273
/**
12151274
* @brief Initialize register with data index.
12161275
*

0 commit comments

Comments
 (0)