@@ -6079,6 +6079,7 @@ void kernel_mul_mv_iq4_ks_f32_impl(
60796079
60806080 float4 yl[4 ];
60816081 float2 sumf = 0 .f ;
6082+ float d[2 ];
60826083
60836084 device const float * yb = y + ix * QK_K + ib * 32 + il * 8 ;
60846085
@@ -6087,22 +6088,25 @@ void kernel_mul_mv_iq4_ks_f32_impl(
60876088
60886089 float4 qf1, qf2;
60896090
6091+ device const float * dptr = (device const float *)cx;
6092+ d[0 ] = *dptr;
6093+ device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1 ) + ix;
6094+ dptr += row_size/4 ;
6095+ d[1 ] = *dptr;
6096+
60906097 for (int ibl = ix; ibl < nb; ibl += 2 ) {
60916098
60926099 device const float4 * y4 = (device const float4 *)yb;
60936100 yl[0 ] = y4[0 ]; yl[1 ] = y4[4 ]; yl[2 ] = y4[1 ]; yl[3 ] = y4[5 ];
60946101
6095- device const float * dptr = (device const float *)cx ;
6102+ device const uint8_t * scales = x-> scales ;
60966103
60976104 for (int row = 0 ; row < 2 ; ++row) {
60986105
6099- // device const float * dptr = (device const float *)(cx + row*row_size);
6100- const float d = *dptr;
6101- device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1 );
6102- device const block_iq4_ks & xb = x[ibl];
6103- device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16 *ib + 8 *il);
6106+ threadgroup const float * block_values = shared_values + ((scales[ib] & 1 ) << 4 );
6107+ const float ls = ((scales[ib] & 254 ) - 127 );
61046108
6105- threadgroup const float * block_values = shared_values + ((xb. scales [ib] & 1 ) << 4 ) ;
6109+ device const uint32_t * q4 = (device const uint32_t *) scales + QK_K/ 128 + 4 *ib + 2 *il ;
61066110
61076111 float4 acc1 = {0 .f }, acc2 = {0 .f };
61086112
@@ -6122,14 +6126,14 @@ void kernel_mul_mv_iq4_ks_f32_impl(
61226126
61236127 acc1 += acc2;
61246128
6125- const int ls = (xb.scales [ib] & 254 ) - 127 ;
6126- sumf[row] += d * ls * (acc1[0 ] + acc1[1 ] + acc1[2 ] + acc1[3 ]);
6129+ sumf[row] += d[row] * ls * (acc1[0 ] + acc1[1 ] + acc1[2 ] + acc1[3 ]);
61276130
6128- dptr += row_size/ 4 ;
6131+ scales += row_size;
61296132
61306133 }
61316134
61326135 yb += 2 * QK_K;
6136+ x += 2 ;
61336137 }
61346138
61356139 sumf = simd_sum (sumf);
0 commit comments