@@ -110,10 +110,29 @@ struct BroadcastDataLoader {
110110 const Array3 &use_broadcast,
111111 const int block_offset,
112112 const int num,
113- const uint32_t numel) {
113+ const uint32_t numel,
114+ int read_lens) {
114115 using Type = std::tuple_element_t <Index, ArgsT>;
116+ #ifdef PADDLE_WITH_XPU_KP
117+ kps::Init<Type, ArgsT, Index, VecSize>(
118+ args, static_cast <Type>(1 .0f ), read_lens);
119+ if (use_broadcast[Index]) {
120+ kps::ReadDataBc<Type, VecSize, 1 , ArgsT, Index, IsBoundary>(
121+ args,
122+ reinterpret_cast <const _ptr_ Type *>(ins[Index]),
123+ block_offset,
124+ configs[Index],
125+ numel,
126+ read_lens);
127+ } else {
128+ kps::ReadData<Type, VecSize, 1 , ArgsT, Index, IsBoundary>(
129+ args,
130+ reinterpret_cast <const _ptr_ Type *>(ins[Index]) + block_offset,
131+ num,
132+ read_lens);
133+ }
134+ #else
115135 kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast <Type>(1 .0f ));
116-
117136 if (use_broadcast[Index]) {
118137 kps::ReadDataBc<Type, VecSize, 1 , ArgsT, Index, IsBoundary>(
119138 args,
@@ -133,6 +152,7 @@ struct BroadcastDataLoader {
133152 num,
134153 VecSize);
135154 }
155+ #endif
136156 }
137157};
138158
@@ -148,7 +168,8 @@ struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
148168 const Array3 &use_broadcast,
149169 const int block_offset,
150170 const int num,
151- const uint32_t numel) {
171+ const uint32_t numel,
172+ int read_lens) {
152173 using Type = std::tuple_element_t <Index, ArgsT>;
153174 int thread_offset = threadIdx.x * VecSize + block_offset;
154175#pragma unroll
@@ -173,7 +194,8 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
173194 const Array3 &use_broadcast,
174195 const int block_offset,
175196 const int num,
176- const uint32_t numel) {
197+ const uint32_t numel,
198+ int read_lens) {
177199 using Type = std::tuple_element_t <Index, ArgsT>;
178200 using VecType = phi::kps::details::VectorType<Type, VecSize>;
179201 VecType vec_temp;
@@ -269,6 +291,10 @@ __device__ void VectorizedBroadcastKernelImpl(
269291 __simd__ ArgsT args[VecSize];
270292 __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
271293
294+ #ifdef PADDLE_WITH_XPU_KP
295+ BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step (
296+ ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
297+ #else
272298 if (LoadType == kBroadcast ) {
273299 uint32_t index_bc[Arity][VecSize] = {0 };
274300 Unroller<BroadcastDataInit, VecSize, Arity>::step (args);
@@ -291,9 +317,9 @@ __device__ void VectorizedBroadcastKernelImpl(
291317 Unroller<BroadcastDataSetter, VecSize, Arity>::step (ins, args, index_bc);
292318 } else {
293319 BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step (
294- ins, args, configs, use_broadcast, block_offset, num, numel);
320+ ins, args, configs, use_broadcast, block_offset, num, numel, read_lens );
295321 }
296-
322+ # endif
297323 SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
298324 VecSize,
299325 Functor,
0 commit comments