@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
11638
11638
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
11639
11639
}
11640
11640
11641
+ static void ggml_rope_cache_init(
11642
+ float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
11643
+ float * cache, float sin_sign, float theta_scale
11644
+ ) {
11645
+ float theta = theta_base;
11646
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11647
+ rope_yarn(
11648
+ theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
11649
+ );
11650
+ cache[i0 + 1] *= sin_sign;
11651
+
11652
+ theta *= theta_scale;
11653
+ }
11654
+ }
11655
+
11641
11656
void ggml_rope_yarn_corr_dims(
11642
11657
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
11643
11658
) {
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
11720
11735
for (int64_t i3 = 0; i3 < ne3; i3++) {
11721
11736
for (int64_t i2 = 0; i2 < ne2; i2++) {
11722
11737
const int64_t p = pos[i2];
11738
+
11739
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11740
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11741
+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11742
+ }
11743
+
11723
11744
for (int64_t i1 = 0; i1 < ne1; i1++) {
11724
11745
if (ir++ < ir0) continue;
11725
11746
if (ir > ir1) break;
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
11753
11774
}
11754
11775
} else if (!is_neox) {
11755
11776
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11756
- float cos_theta, sin_theta;
11757
- rope_yarn(
11758
- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11759
- );
11760
- sin_theta *= sin_sign;
11777
+ const float cos_theta = cache[i0 + 0];
11778
+ const float sin_theta = cache[i0 + 1];
11761
11779
11762
11780
// zeta scaling for xPos only:
11763
11781
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
11764
11782
if (xpos_down) zeta = 1.0f / zeta;
11765
11783
11766
- theta_base *= theta_scale;
11767
-
11768
11784
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11769
11785
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11770
11786
@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
11888
11904
for (int64_t i3 = 0; i3 < ne3; i3++) {
11889
11905
for (int64_t i2 = 0; i2 < ne2; i2++) {
11890
11906
const int64_t p = pos[i2];
11907
+
11908
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
11909
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
11910
+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
11911
+ }
11912
+
11891
11913
for (int64_t i1 = 0; i1 < ne1; i1++) {
11892
11914
if (ir++ < ir0) continue;
11893
11915
if (ir > ir1) break;
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
11921
11943
}
11922
11944
} else if (!is_neox) {
11923
11945
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11924
- float cos_theta, sin_theta;
11925
- rope_yarn(
11926
- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11927
- );
11928
- sin_theta *= sin_sign;
11929
-
11930
- theta_base *= theta_scale;
11946
+ const float cos_theta = cache[i0 + 0];
11947
+ const float sin_theta = cache[i0 + 1];
11931
11948
11932
11949
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11933
11950
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
16722
16739
}
16723
16740
} break;
16724
16741
case GGML_OP_SOFT_MAX:
16742
+ case GGML_OP_ROPE:
16725
16743
{
16726
16744
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
16727
16745
} break;
0 commit comments