Skip to content

Commit

Permalink
ggml : rms_norm in chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 17, 2023
1 parent add49f6 commit 14a7dce
Showing 1 changed file with 38 additions and 23 deletions.
61 changes: 38 additions & 23 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -9018,18 +9018,20 @@ static void ggml_compute_forward_rms_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
atomic_store(params->aic, 0);

return;
}

GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int ith = params->ith; UNUSED(ith);
const int nth = params->nth;

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne03 = src0->ne[3]; UNUSED(ne03);

const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
Expand All @@ -9041,30 +9043,45 @@ static void ggml_compute_forward_rms_norm_f32(

const float eps = 1e-6f; // TODO: make this a parameter

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}
const int nr = ggml_nrows(src0);
const int dr = (nr + 8*nth - 1)/(8*nth);

float mean = sum/ne00;
while (true) {
const int ir0 = atomic_fetch_add(params->aic, dr);

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
for (int ir = ir0; ir < ir0 + dr; ++ir) {
if (ir >= nr) {
break;
}

memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

const float scale = 1.0f/sqrtf(mean + eps);
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

ggml_vec_scale_f32(ne00, y, scale);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}

float mean = sum/ne00;

float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }

const float scale = 1.0f/sqrtf(mean + eps);

ggml_vec_scale_f32(ne00, y, scale);
}

if (ir0 + dr >= nr) {
break;
}
}
}
Expand Down Expand Up @@ -9739,11 +9756,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];

const int ith = params->ith;
const int ith = params->ith; UNUSED(ith);
const int nth = params->nth;

UNUSED(ith);

GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
Expand Down

0 comments on commit 14a7dce

Please sign in to comment.