@@ -5,11 +5,13 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
5
5
6
6
const int nrows = item_ct1.get_group_range (2 );
7
7
const int nchannels = item_ct1.get_group_range (1 );
8
+ const int nthreads = item_ct1.get_local_range (2 );
8
9
const int sample = item_ct1.get_group (0 );
9
10
const int channel = item_ct1.get_group (1 );
10
11
const int row = item_ct1.get_group (2 );
11
12
12
13
const int tid = item_ct1.get_local_id (2 );
14
+ const int nwarps = nthreads / WARP_SIZE;
13
15
14
16
x += sample * stride_sample + channel * stride_channel + row * stride_row;
15
17
dst += ((sample * nchannels + channel) * nrows + row) * ncols;
@@ -30,8 +32,12 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
30
32
s_sum[warp_id] = mean_var;
31
33
}
32
34
item_ct1.barrier (sycl::access ::fence_space::local_space);
33
-
34
- mean_var = s_sum[lane_id];
35
+ mean_var = 0 .f ;
36
+ size_t nreduce = nwarps / WARP_SIZE;
37
+ for (size_t i = 0 ; i < nreduce; i += 1 )
38
+ {
39
+ mean_var += s_sum[lane_id + i * WARP_SIZE];
40
+ }
35
41
mean_var = warp_reduce_sum (mean_var, item_ct1);
36
42
}
37
43
@@ -139,8 +145,10 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
139
145
const int sample = item_ct1.get_group (0 );
140
146
const int channel = item_ct1.get_group (1 );
141
147
const int row = item_ct1.get_group (2 );
148
+ const int nthreads = item_ct1.get_local_range (2 );
142
149
143
150
const int tid = item_ct1.get_local_id (2 );
151
+ const int nwarps = nthreads / WARP_SIZE;
144
152
145
153
x += sample*stride_sample + channel*stride_channel + row*stride_row;
146
154
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -164,7 +172,12 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
164
172
}
165
173
166
174
item_ct1.barrier (sycl::access ::fence_space::local_space);
167
- tmp = s_sum[lane_id];
175
+ size_t nreduce = nwarps / WARP_SIZE;
176
+ tmp = 0 .f ;
177
+ for (size_t i = 0 ; i < nreduce; i += 1 )
178
+ {
179
+ tmp += s_sum[lane_id + i * WARP_SIZE];
180
+ }
168
181
tmp = warp_reduce_sum (tmp, item_ct1);
169
182
}
170
183
0 commit comments