1 file changed: +5 -25 lines changed
Original file line number | Diff line number | Diff line change
@@ -220,29 +220,14 @@ kernel void kernel_norm(
220
220
}
221
221
threadgroup_barrier (mem_flags::mem_threadgroup);
222
222
}
223
- // broadcast
224
- if (tpitg == 0) {
225
- sum[0] /= ne00;
226
- }
227
- threadgroup_barrier (mem_flags::mem_threadgroup);
228
- const float mean = sum[0];
223
+ const float mean = sum[0] / ne00;
229
224
230
- // recenter
225
+ // recenter and VARIANCE
226
+ threadgroup_barrier (mem_flags::mem_threadgroup);
231
227
device float * y = dst + tgpig*ne00;
232
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
233
- y[i00] = x[i00] - mean;
234
- }
235
-
236
- // VARIANCE
237
- // parallel sum
238
- //
239
- // WARNING: combining this loop with the one above will give you wrong results for nth == 256
240
- // I have no idea why, so for now I am keeping them separate. But this behavior is very concerning.
241
- // Tested with:
242
- // ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4
243
- //
244
228
sum[tpitg] = 0.0f;
245
229
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
230
+ y[i00] = x[i00] - mean;
246
231
sum[tpitg] += y[i00] * y[i00];
247
232
}
248
233
@@ -254,12 +239,7 @@ kernel void kernel_norm(
254
239
}
255
240
threadgroup_barrier (mem_flags::mem_threadgroup);
256
241
}
257
- // broadcast
258
- if (tpitg == 0) {
259
- sum[0] /= ne00;
260
- }
261
- threadgroup_barrier (mem_flags::mem_threadgroup);
262
- const float variance = sum[0];
242
+ const float variance = sum[0] / ne00;
263
243
264
244
const float scale = 1.0f/sqrt(variance + eps);
265
245
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
You can’t perform that action at this time.
0 commit comments