Skip to content

Commit

Permalink
[FIX] Updating SYCL Shuffle API (#2737)
Browse files Browse the repository at this point in the history
  • Loading branch information
yisonzhu authored Jul 25, 2024
1 parent fec65e2 commit a990da2
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion itex/core/kernels/gpu/full_reduction_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ struct GroupReduceKernel<itex::int64, OutputT, InputFunctor, OutputFunctor,
InitValueT result = value;
#pragma unroll
for (int i = SubGroupSize / 2; i > 0; i >>= 1) {
InitValueT new_value = sg.shuffle_down(result, i);
InitValueT new_value = sycl::shift_group_left(sg, result, i);
result = op_(result, new_value);
}
return result;
Expand Down
3 changes: 2 additions & 1 deletion itex/core/kernels/gpu/softmax_op_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ template <template <typename> typename ReductionOp, typename T,
int workitem_group_width = kSubGroupSize>
T SubGroupAllReduce(const sycl::sub_group& sg, T val) {
for (int mask = workitem_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, sg.shuffle_xor(val, sycl::id<1>(mask)));
val = ReductionOp<T>()(
val, sycl::permute_group_by_xor(sg, val, sycl::id<1>(mask)));
}
return val;
}
Expand Down
3 changes: 2 additions & 1 deletion itex/core/kernels/gpu/sparse_xent_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ template <template <typename> typename ReductionOp, typename T,
int workitem_group_width = kSubGroupSize>
T SubGroupAllReduce(const sycl::sub_group& sg, T val) {
for (int mask = workitem_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, sg.shuffle_xor(val, sycl::id<1>(mask)));
val = ReductionOp<T>()(
val, sycl::permute_group_by_xor(sg, val, sycl::id<1>(mask)));
}
return val;
}
Expand Down
3 changes: 2 additions & 1 deletion itex/core/kernels/gpu/xent_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ template <template <typename> typename ReductionOp, typename T,
int workitem_group_width = kSubGroupSize>
T SubGroupAllReduce(const sycl::sub_group& sg, T val) {
for (int mask = workitem_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, sg.shuffle_xor(val, sycl::id<1>(mask)));
val = ReductionOp<T>()(
val, sycl::permute_group_by_xor(sg, val, sycl::id<1>(mask)));
}
return val;
}
Expand Down

0 comments on commit a990da2

Please sign in to comment.