Skip to content

Commit be6beeb

Browse files
ikawrakowKawrakowggerganov
authored
metal : correct fix of kernel_norm (#3060)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent c4f4966 commit be6beeb

File tree

1 file changed

+5
-25
lines changed

1 file changed

+5
-25
lines changed

ggml-metal.metal

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -220,29 +220,14 @@ kernel void kernel_norm(
220220
}
221221
threadgroup_barrier(mem_flags::mem_threadgroup);
222222
}
223-
// broadcast
224-
if (tpitg == 0) {
225-
sum[0] /= ne00;
226-
}
227-
threadgroup_barrier(mem_flags::mem_threadgroup);
228-
const float mean = sum[0];
223+
const float mean = sum[0] / ne00;
229224

230-
// recenter
225+
// recenter and VARIANCE
226+
threadgroup_barrier(mem_flags::mem_threadgroup);
231227
device float * y = dst + tgpig*ne00;
232-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
233-
y[i00] = x[i00] - mean;
234-
}
235-
236-
// VARIANCE
237-
// parallel sum
238-
//
239-
// WARNING: combining this loop with the one above will give you wrong results for nth == 256
240-
// I have no idea why, so for now I am keeping them separate. But this behavior is very concerning.
241-
// Tested with:
242-
// ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4
243-
//
244228
sum[tpitg] = 0.0f;
245229
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
230+
y[i00] = x[i00] - mean;
246231
sum[tpitg] += y[i00] * y[i00];
247232
}
248233

@@ -254,12 +239,7 @@ kernel void kernel_norm(
254239
}
255240
threadgroup_barrier(mem_flags::mem_threadgroup);
256241
}
257-
// broadcast
258-
if (tpitg == 0) {
259-
sum[0] /= ne00;
260-
}
261-
threadgroup_barrier(mem_flags::mem_threadgroup);
262-
const float variance = sum[0];
242+
const float variance = sum[0] / ne00;
263243

264244
const float scale = 1.0f/sqrt(variance + eps);
265245
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {

0 commit comments

Comments
 (0)