Skip to content

Commit 67c4a8c

Browse files
committed
restore subgroup reduction for multi-subgroup thread blocks in norm kernels
1 parent 71145b5 commit 67c4a8c

File tree

1 file changed: +16 −3 lines changed

ggml/src/ggml-sycl/norm.cpp

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,13 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
55

66
const int nrows = item_ct1.get_group_range(2);
77
const int nchannels = item_ct1.get_group_range(1);
8+
const int nthreads = item_ct1.get_local_range(2);
89
const int sample = item_ct1.get_group(0);
910
const int channel = item_ct1.get_group(1);
1011
const int row = item_ct1.get_group(2);
1112

1213
const int tid = item_ct1.get_local_id(2);
14+
const int nwarps = nthreads / WARP_SIZE;
1315

1416
x += sample * stride_sample + channel * stride_channel + row * stride_row;
1517
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
@@ -30,8 +32,12 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
3032
s_sum[warp_id] = mean_var;
3133
}
3234
item_ct1.barrier(sycl::access::fence_space::local_space);
33-
34-
mean_var = s_sum[lane_id];
35+
mean_var = 0.f;
36+
size_t nreduce = nwarps / WARP_SIZE;
37+
for (size_t i = 0; i < nreduce; i += 1)
38+
{
39+
mean_var += s_sum[lane_id + i * WARP_SIZE];
40+
}
3541
mean_var = warp_reduce_sum(mean_var, item_ct1);
3642
}
3743

@@ -139,8 +145,10 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
139145
const int sample = item_ct1.get_group(0);
140146
const int channel = item_ct1.get_group(1);
141147
const int row = item_ct1.get_group(2);
148+
const int nthreads = item_ct1.get_local_range(2);
142149

143150
const int tid = item_ct1.get_local_id(2);
151+
const int nwarps = nthreads / WARP_SIZE;
144152

145153
x += sample*stride_sample + channel*stride_channel + row*stride_row;
146154
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -164,7 +172,12 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
164172
}
165173

166174
item_ct1.barrier(sycl::access::fence_space::local_space);
167-
tmp = s_sum[lane_id];
175+
size_t nreduce = nwarps / WARP_SIZE;
176+
tmp = 0.f;
177+
for (size_t i = 0; i < nreduce; i += 1)
178+
{
179+
tmp += s_sum[lane_id + i * WARP_SIZE];
180+
}
168181
tmp = warp_reduce_sum(tmp, item_ct1);
169182
}
170183

0 commit comments

Comments
 (0)