@@ -536,14 +536,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
536
536
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
537
537
538
538
float sumf = 0 ;
539
- for (int i = tiisg; i < ne00; i += 32 ) {
540
- sumf += (float ) x[i] * (float ) y[i];
539
+ if (ne00 < 128 ) {
540
+ for (int i = tiisg; i < ne00; i += 32 ) {
541
+ sumf += (float ) x[i] * (float ) y[i];
542
+ }
543
+ float all_sum = simd_sum (sumf);
544
+ if (tiisg == 0 ) {
545
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
546
+ }
547
+ } else {
548
+ device const half4 * x4 = (device const half4 *) x;
549
+ device const float4 * y4 = (device const float4 *) y;
550
+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
551
+ for (int k = 0 ; k < 4 ; ++k) sumf += (float )x4[i][k] * y4[i][k];
552
+ }
553
+ float all_sum = simd_sum (sumf);
554
+ if (tiisg == 0 ) {
555
+ for (int i = 4 *(ne00/4 ); i < ne00; ++i) all_sum += (float ) x[i] * y[i];
556
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
557
+ }
541
558
}
542
559
543
- float all_sum = simd_sum (sumf);
544
- if (tiisg == 0 ) {
545
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
546
- }
547
560
}
548
561
549
562
// Number of consecutive r1 indices covered per tgpig.y value in
// kernel_mul_mat_f16_f32: that kernel starts at rb = tgpig.y*N_F16_F32 and
// iterates row = 0 .. N_F16_F32-1 with r1 = rb + row, stopping early once
// r1 >= ne11.
#define N_F16_F32 4
@@ -570,29 +583,54 @@ kernel void kernel_mul_mat_f16_f32(
570
583
uint tiisg[[thread_index_in_simdgroup]]) {
571
584
572
585
const int64_t r0 = tgpig.x ;
573
- const int64_t rb = N_F16_F32* tgpig.y ;
586
+ const int64_t rb = tgpig.y *N_F16_F32 ;
574
587
const int64_t im = tgpig.z ;
575
588
576
589
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
577
590
578
- for (int row = 0 ; row < N_F16_F32; ++row) {
579
- int r1 = rb + row;
580
- if (r1 >= ne11) {
581
- break ;
582
- }
591
+ if (ne00 < 128 ) {
592
+ for (int row = 0 ; row < N_F16_F32; ++row) {
593
+ int r1 = rb + row;
594
+ if (r1 >= ne11) {
595
+ break ;
596
+ }
583
597
584
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
598
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
585
599
586
- float sumf = 0 ;
587
- for (int i = tiisg; i < ne00; i += 32 ) {
588
- sumf += (float ) x[i] * (float ) y[i];
600
+ float sumf = 0 ;
601
+ for (int i = tiisg; i < ne00; i += 32 ) {
602
+ sumf += (float ) x[i] * (float ) y[i];
603
+ }
604
+
605
+ float all_sum = simd_sum (sumf);
606
+ if (tiisg == 0 ) {
607
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
608
+ }
589
609
}
610
+ } else {
611
+ device const half4 * x4 = (device const half4 *)x;
612
+ for (int row = 0 ; row < N_F16_F32; ++row) {
613
+ int r1 = rb + row;
614
+ if (r1 >= ne11) {
615
+ break ;
616
+ }
590
617
591
- float all_sum = simd_sum (sumf);
592
- if (tiisg == 0 ) {
593
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
618
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
619
+ device const float4 * y4 = (device const float4 *) y;
620
+
621
+ float sumf = 0 ;
622
+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
623
+ for (int k = 0 ; k < 4 ; ++k) sumf += (float ) x4[i][k] * y4[i][k];
624
+ }
625
+
626
+ float all_sum = simd_sum (sumf);
627
+ if (tiisg == 0 ) {
628
+ for (int i = 4 *(ne00/4 ); i < ne00; ++i) all_sum += (float ) x[i] * y[i];
629
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
630
+ }
594
631
}
595
632
}
633
+
596
634
}
597
635
598
636
kernel void kernel_alibi_f32 (
0 commit comments