Skip to content

Commit c71d608

Browse files
ggml: cache sin/cos for RoPE (#4908)
1 parent 4be5ef5 commit c71d608

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

ggml.c

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
1163811638
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
1163911639
}
1164011640

11641+
static void ggml_rope_cache_init(
11642+
float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
11643+
float * cache, float sin_sign, float theta_scale
11644+
) {
11645+
float theta = theta_base;
11646+
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11647+
rope_yarn(
11648+
theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
11649+
);
11650+
cache[i0 + 1] *= sin_sign;
11651+
11652+
theta *= theta_scale;
11653+
}
11654+
}
11655+
1164111656
void ggml_rope_yarn_corr_dims(
1164211657
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1164311658
) {
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
1172011735
for (int64_t i3 = 0; i3 < ne3; i3++) {
1172111736
for (int64_t i2 = 0; i2 < ne2; i2++) {
1172211737
const int64_t p = pos[i2];
11738+
11739+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11740+
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11741+
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11742+
}
11743+
1172311744
for (int64_t i1 = 0; i1 < ne1; i1++) {
1172411745
if (ir++ < ir0) continue;
1172511746
if (ir > ir1) break;
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
1175311774
}
1175411775
} else if (!is_neox) {
1175511776
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11756-
float cos_theta, sin_theta;
11757-
rope_yarn(
11758-
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11759-
);
11760-
sin_theta *= sin_sign;
11777+
const float cos_theta = cache[i0 + 0];
11778+
const float sin_theta = cache[i0 + 1];
1176111779

1176211780
// zeta scaling for xPos only:
1176311781
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
1176411782
if (xpos_down) zeta = 1.0f / zeta;
1176511783

11766-
theta_base *= theta_scale;
11767-
1176811784
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1176911785
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1177011786

@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
1188811904
for (int64_t i3 = 0; i3 < ne3; i3++) {
1188911905
for (int64_t i2 = 0; i2 < ne2; i2++) {
1189011906
const int64_t p = pos[i2];
11907+
11908+
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11909+
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11910+
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11911+
}
11912+
1189111913
for (int64_t i1 = 0; i1 < ne1; i1++) {
1189211914
if (ir++ < ir0) continue;
1189311915
if (ir > ir1) break;
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
1192111943
}
1192211944
} else if (!is_neox) {
1192311945
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11924-
float cos_theta, sin_theta;
11925-
rope_yarn(
11926-
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11927-
);
11928-
sin_theta *= sin_sign;
11929-
11930-
theta_base *= theta_scale;
11946+
const float cos_theta = cache[i0 + 0];
11947+
const float sin_theta = cache[i0 + 1];
1193111948

1193211949
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1193311950
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
1672216739
}
1672316740
} break;
1672416741
case GGML_OP_SOFT_MAX:
16742+
case GGML_OP_ROPE:
1672516743
{
1672616744
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1672716745
} break;

0 commit comments

Comments
 (0)