@@ -682,25 +682,27 @@ kernel void kernel_rope(
682
682
constant int & mode,
683
683
constant float & freq_base,
684
684
constant float & freq_scale,
685
- uint3 tpig[[thread_position_in_grid]]) {
686
- const int64_t i3 = tpig[2 ];
687
- const int64_t i2 = tpig[1 ];
688
- const int64_t i1 = tpig[0 ];
685
+ uint tiitg[[thread_index_in_threadgroup]],
686
+ uint3 tptg[[threads_per_threadgroup]],
687
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
688
+ const int64_t i3 = tgpig[2 ];
689
+ const int64_t i2 = tgpig[1 ];
690
+ const int64_t i1 = tgpig[0 ];
689
691
690
692
const bool is_neox = mode & 2 ;
691
- const float theta_scale = pow (freq_base, -2 .0f /n_dims);
692
693
693
694
const int64_t p = ((mode & 1 ) == 0 ? n_past + i2 : i2);
694
695
695
- float theta = freq_scale * (float )p;
696
+ const float theta_0 = freq_scale * (float )p;
697
+ const float inv_ndims = -1 .f /n_dims;
696
698
697
699
if (!is_neox) {
698
- for (int64_t i0 = 0 ; i0 < ne0; i0 += 2 ) {
700
+ for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg.x ) {
701
+
702
+ const float theta = theta_0 * pow (freq_base, inv_ndims*i0);
699
703
const float cos_theta = cos (theta);
700
704
const float sin_theta = sin (theta);
701
705
702
- theta *= theta_scale;
703
-
704
706
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
705
707
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
706
708
@@ -712,12 +714,12 @@ kernel void kernel_rope(
712
714
}
713
715
} else {
714
716
for (int64_t ib = 0 ; ib < ne0/n_dims; ++ib) {
715
- for (int64_t ic = 0 ; ic < n_dims; ic += 2 ) {
717
+ for (int64_t ic = 2 *tiitg; ic < n_dims; ic += 2 *tptg.x ) {
718
+
719
+ const float theta = theta_0 * pow (freq_base, inv_ndims*ic - ib);
716
720
const float cos_theta = cos (theta);
717
721
const float sin_theta = sin (theta);
718
722
719
- theta *= theta_scale;
720
-
721
723
const int64_t i0 = ib*n_dims + ic/2 ;
722
724
723
725
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
0 commit comments