@@ -9033,18 +9033,20 @@ static void ggml_compute_forward_rms_norm_f32(
90339033 GGML_ASSERT (ggml_are_same_shape (src0 , dst ));
90349034
90359035 if (params -> type == GGML_TASK_INIT || params -> type == GGML_TASK_FINALIZE ) {
9036+ atomic_store (params -> aic , 0 );
9037+
90369038 return ;
90379039 }
90389040
90399041 GGML_ASSERT (src0 -> nb [0 ] == sizeof (float ));
90409042
9041- const int ith = params -> ith ;
9043+ const int ith = params -> ith ; UNUSED ( ith );
90429044 const int nth = params -> nth ;
90439045
90449046 const int64_t ne00 = src0 -> ne [0 ];
90459047 const int64_t ne01 = src0 -> ne [1 ];
90469048 const int64_t ne02 = src0 -> ne [2 ];
9047- const int64_t ne03 = src0 -> ne [3 ];
9049+ const int64_t ne03 = src0 -> ne [3 ]; UNUSED ( ne03 );
90489050
90499051 const size_t nb01 = src0 -> nb [1 ];
90509052 const size_t nb02 = src0 -> nb [2 ];
@@ -9056,30 +9058,45 @@ static void ggml_compute_forward_rms_norm_f32(
90569058
90579059 const float eps = 1e-6f ; // TODO: make this a parameter
90589060
9059- // TODO: optimize
9060- for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
9061- for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
9062- for (int64_t i01 = ith ; i01 < ne01 ; i01 += nth ) {
9063- const float * x = (float * ) ((char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
9064-
9065- ggml_float sum = 0.0 ;
9066- for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9067- sum += (ggml_float )(x [i00 ] * x [i00 ]);
9068- }
9061+ const int nr = ggml_nrows (src0 );
9062+ const int dr = (nr + 8 * nth - 1 )/(8 * nth );
90699063
9070- float mean = sum /ne00 ;
9064+ while (true) {
9065+ const int ir0 = atomic_fetch_add (params -> aic , dr );
90719066
9072- float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9067+ for (int ir = ir0 ; ir < ir0 + dr ; ++ ir ) {
9068+ if (ir >= nr ) {
9069+ break ;
9070+ }
90739071
9074- memcpy ( y , x , ne00 * sizeof ( float ));
9075- // for ( int i00 = 0; i00 < ne00; i00++) {
9076- // y[i00] = x[i00] ;
9077- // }
9072+ // src0 indices
9073+ const int i03 = ir /( ne02 * ne01 );
9074+ const int i02 = ( ir - i03 * ne02 * ne01 )/ ne01 ;
9075+ const int i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );
90789076
9079- const float scale = 1.0f / sqrtf ( mean + eps );
9077+ const float * x = ( float * ) (( char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
90809078
9081- ggml_vec_scale_f32 (ne00 , y , scale );
9079+ ggml_float sum = 0.0 ;
9080+ for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9081+ sum += (ggml_float )(x [i00 ] * x [i00 ]);
90829082 }
9083+
9084+ float mean = sum /ne00 ;
9085+
9086+ float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9087+
9088+ memcpy (y , x , ne00 * sizeof (float ));
9089+ // for (int i00 = 0; i00 < ne00; i00++) {
9090+ // y[i00] = x[i00];
9091+ // }
9092+
9093+ const float scale = 1.0f /sqrtf (mean + eps );
9094+
9095+ ggml_vec_scale_f32 (ne00 , y , scale );
9096+ }
9097+
9098+ if (ir0 + dr >= nr ) {
9099+ break ;
90839100 }
90849101 }
90859102}
@@ -9754,11 +9771,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
97549771 const int nb2 = dst -> nb [2 ];
97559772 const int nb3 = dst -> nb [3 ];
97569773
9757- const int ith = params -> ith ;
9774+ const int ith = params -> ith ; UNUSED ( ith );
97589775 const int nth = params -> nth ;
97599776
9760- UNUSED (ith );
9761-
97629777 GGML_ASSERT (ne02 == ne12 );
97639778 GGML_ASSERT (ne03 == ne13 );
97649779 GGML_ASSERT (ne2 == ne12 );
0 commit comments