@@ -6313,7 +6313,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63136313    simdgroup_T8x8    ma[4 ];
63146314    simdgroup_half8x8 mb[2 ];
63156315    simdgroup_half8x8 mc[8 ];
6316-     for  (int  i = 0 ; i < 8 ; i++){
6316+     for  (short  i = 0 ; i < 8 ; i++){
63176317        mc[i] = make_filled_simdgroup_matrix<half, 8 >(0 .h );
63186318    }
63196319
@@ -6339,7 +6339,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63396339        threadgroup_barrier (mem_flags::mem_threadgroup);
63406340
63416341        #pragma  unroll(16)
6342-         for  (int  i = 0 ; i < 16 ; i++) {
6342+         for  (short  i = 0 ; i < 16 ; i++) {
63436343            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8 ) \
63446344            +                     (tiitg % THREAD_PER_ROW) * 16  + (i / 8 ) * 8 ) \
63456345            +                     (tiitg / THREAD_PER_ROW) % 8   + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
@@ -6358,22 +6358,22 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63586358        threadgroup half * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
63596359
63606360        #pragma  unroll(4)
6361-         for  (int  ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6361+         for  (short  ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
63626362            #pragma  unroll(4)
6363-             for  (int  i = 0 ; i < 4 ; i++) {
6363+             for  (short  i = 0 ; i < 4 ; i++) {
63646364                simdgroup_load (ma[i],lsma + SG_MAT_SIZE * i);
63656365            }
63666366            simdgroup_barrier (mem_flags::mem_none);
63676367            #pragma  unroll(2)
6368-             for  (int  i = 0 ; i < 2 ; i++) {
6368+             for  (short  i = 0 ; i < 2 ; i++) {
63696369                simdgroup_load (mb[i],lsmb + SG_MAT_SIZE * i);
63706370            }
63716371
63726372            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
63736373            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
63746374
63756375            #pragma  unroll(8)
6376-             for  (int  i = 0 ; i < 8 ; i++){
6376+             for  (short  i = 0 ; i < 8 ; i++){
63776377                simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
63786378            }
63796379        }
@@ -6382,7 +6382,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63826382    if  ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
63836383        device float  * C = dst + (BLOCK_SIZE_M * r0 + 32  * (sgitg &  1 )) \
63846384                               + (BLOCK_SIZE_N * r1 + 16  * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6385-         for  (int  i = 0 ; i < 8 ; i++) {
6385+         for  (short  i = 0 ; i < 8 ; i++) {
63866386            //  cast to f32
63876387            simdgroup_float8x8 mc_f32 (1 .0f );
63886388            simdgroup_multiply (mc_f32, mc[i], mc_f32);
@@ -6394,7 +6394,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63946394        threadgroup_barrier (mem_flags::mem_threadgroup);
63956395        threadgroup float  * temp_str = ((threadgroup float  *)shared_memory) \
63966396                                       + 32  * (sgitg&1 ) + (16  * (sgitg>>1 )) * BLOCK_SIZE_M;
6397-         for  (int  i = 0 ; i < 8 ; i++) {
6397+         for  (short  i = 0 ; i < 8 ; i++) {
63986398            simdgroup_float8x8 mc_f32 (1 .0f );
63996399            simdgroup_multiply (mc_f32, mc[i], mc_f32);
64006400            simdgroup_store (mc_f32, temp_str + 8  * (i%4 ) + 8  * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
0 commit comments