Skip to content

Commit b7f2aa9

Browse files
authored
metal : restore 363f0bf and fix reduce in F16_F32 kernels (#2986)
1 parent 73a12a6 commit b7f2aa9

File tree

1 file changed

+57
-19
lines changed

1 file changed

+57
-19
lines changed

ggml-metal.metal

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -536,14 +536,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
536536
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
537537

538538
float sumf = 0;
539-
for (int i = tiisg; i < ne00; i += 32) {
540-
sumf += (float) x[i] * (float) y[i];
539+
if (ne00 < 128) {
540+
for (int i = tiisg; i < ne00; i += 32) {
541+
sumf += (float) x[i] * (float) y[i];
542+
}
543+
float all_sum = simd_sum(sumf);
544+
if (tiisg == 0) {
545+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
546+
}
547+
} else {
548+
device const half4 * x4 = (device const half4 *) x;
549+
device const float4 * y4 = (device const float4 *) y;
550+
for (int i = tiisg; i < ne00/4; i += 32) {
551+
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
552+
}
553+
float all_sum = simd_sum(sumf);
554+
if (tiisg == 0) {
555+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
556+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
557+
}
541558
}
542559

543-
float all_sum = simd_sum(sumf);
544-
if (tiisg == 0) {
545-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
546-
}
547560
}
548561

549562
#define N_F16_F32 4
@@ -570,29 +583,54 @@ kernel void kernel_mul_mat_f16_f32(
570583
uint tiisg[[thread_index_in_simdgroup]]) {
571584

572585
const int64_t r0 = tgpig.x;
573-
const int64_t rb = N_F16_F32*tgpig.y;
586+
const int64_t rb = tgpig.y*N_F16_F32;
574587
const int64_t im = tgpig.z;
575588

576589
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
577590

578-
for (int row = 0; row < N_F16_F32; ++row) {
579-
int r1 = rb + row;
580-
if (r1 >= ne11) {
581-
break;
582-
}
591+
if (ne00 < 128) {
592+
for (int row = 0; row < N_F16_F32; ++row) {
593+
int r1 = rb + row;
594+
if (r1 >= ne11) {
595+
break;
596+
}
583597

584-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
598+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
585599

586-
float sumf = 0;
587-
for (int i = tiisg; i < ne00; i += 32) {
588-
sumf += (float) x[i] * (float) y[i];
600+
float sumf = 0;
601+
for (int i = tiisg; i < ne00; i += 32) {
602+
sumf += (float) x[i] * (float) y[i];
603+
}
604+
605+
float all_sum = simd_sum(sumf);
606+
if (tiisg == 0) {
607+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
608+
}
589609
}
610+
} else {
611+
device const half4 * x4 = (device const half4 *)x;
612+
for (int row = 0; row < N_F16_F32; ++row) {
613+
int r1 = rb + row;
614+
if (r1 >= ne11) {
615+
break;
616+
}
590617

591-
float all_sum = simd_sum(sumf);
592-
if (tiisg == 0) {
593-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
618+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
619+
device const float4 * y4 = (device const float4 *) y;
620+
621+
float sumf = 0;
622+
for (int i = tiisg; i < ne00/4; i += 32) {
623+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
624+
}
625+
626+
float all_sum = simd_sum(sumf);
627+
if (tiisg == 0) {
628+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
629+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
630+
}
594631
}
595632
}
633+
596634
}
597635

598636
kernel void kernel_alibi_f32(

0 commit comments

Comments
 (0)