@@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    // const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
 
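
Side note on the loop restructuring in the hunk above (an illustration, not part of the patch): the removed code carved ggml_nrows(src0) into one contiguous chunk of dr rows per thread (ir0..ir1), whereas the new triple loop keeps the explicit (i01, i02, i03) indices it needs for the mask broadcast and instead strides over the rows of each slice with i01 = ith; i01 += nth. The standalone sketch below, with made-up nr and nth values, prints which rows each thread handles under both schemes; every row is still visited exactly once.

// Standalone sketch comparing the old and new thread partitioning of the rows.
// nr, nth and MIN are illustrative; only the partitioning logic mirrors the diff.
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // rows in one slice (made up)
    const int nth = 3;  // number of threads (made up)

    for (int ith = 0; ith < nth; ith++) {
        // old scheme: contiguous block of dr rows per thread
        const int dr  = (nr + nth - 1)/nth;
        const int ir0 = dr*ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d  old: rows [%d, %d)  new:", ith, ir0, ir1);

        // new scheme: interleaved rows, thread ith takes ith, ith+nth, ith+2*nth, ...
        for (int i01 = ith; i01 < nr; i01 += nth) {
            printf(" %d", i01);
        }
        printf("\n");
    }
    return 0;
}
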
@@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type    const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type    const k_vec_dot_type      = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
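
The same wrap-around idea drives both mask changes: in soft_max the mask slice is picked with i12 = i02 % ne12 and i13 = i03 % ne13, and in flash_attn_ext the added offset (iq3 % mask->ne[2])*mask->nb[2] does the analogous thing for the batch dimension, so a mask whose broadcast dimension has size 1 is simply reused for every head/batch of the activations. A minimal standalone sketch of that indexing, with made-up shapes (nothing here is ggml API):

// Standalone sketch of the modulo broadcast used above. Shapes are invented;
// only the index arithmetic mirrors the diff.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne02 = 4, ne03 = 2;   // src0: e.g. 4 heads, 2 sequences (made up)
    const int64_t ne12 = 1, ne13 = 2;   // mask: dim 2 has size 1 -> broadcast across heads

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            // same wrap-around as the patch: a size-1 mask dim always maps to slice 0
            const int64_t i12 = i02 % ne12;
            const int64_t i13 = i03 % ne13;
            printf("src0 slice (i02=%lld, i03=%lld) -> mask slice (i12=%lld, i13=%lld)\n",
                   (long long) i02, (long long) i03, (long long) i12, (long long) i13);
        }
    }
    return 0;
}

With ne12 == 1 every head maps to mask slice 0, which is exactly the single-mask-shared-across-heads case the broadcast is meant to cover.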