From e852cddca40ebd06d2082079f907d588490191f6 Mon Sep 17 00:00:00 2001 From: Roy Oursler Date: Fri, 13 Sep 2024 11:26:57 -0700 Subject: [PATCH] xe: ocl: gemm: remove conditionals in main loop --- src/gpu/intel/ocl/gemm/ref_gemm.cl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/gpu/intel/ocl/gemm/ref_gemm.cl b/src/gpu/intel/ocl/gemm/ref_gemm.cl index c020a1c63bf..8580c173b7e 100644 --- a/src/gpu/intel/ocl/gemm/ref_gemm.cl +++ b/src/gpu/intel/ocl/gemm/ref_gemm.cl @@ -34,8 +34,8 @@ void get_strides(int mask, long dim0, long dim1, long dim2, long *str0, __kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b, __global C_DATA_T *c, __global BIA_DATA_T *bias, long offset_a0, long offset_b0, long offset_c0, long offset_bias0, int transa, - int transb, long MB, long M, long N, long K, long stride_a, - long stride_b, long stride_c, long lda, long ldb, long ldc, + int transb, long MB, long M, long N, long K, long stride_a_mb, + long stride_b_mb, long stride_c, long lda, long ldb, long ldc, float eltwise_alpha, float eltwise_beta, float eltwise_scale, int bias_mask, __global int *ao, __global int *bo, __global int *c0, int c0_mask, __global float *scales, long scale_stride, float beta) { @@ -62,10 +62,15 @@ __kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b, c0_mask, MB, M, N, &c0_strides[0], &c0_strides[1], &c0_strides[2]); #endif + long stride_a_m = transa ? lda : 1; + long stride_a_k = transa ? 1 : lda; + long stride_b_k = transb ? ldb : 1; + long stride_b_n = transb ? 1 : ldb; + ACC_DATA_T acc = 0; for (long k = 0; k < K; ++k) { - long off_a = mb * stride_a + (transa ? m * lda + k : k * lda + m); - long off_b = mb * stride_b + (transb ? k * ldb + n : n * ldb + k); + long off_a = mb * stride_a_mb + m * stride_a_m + k * stride_a_k; + long off_b = mb * stride_b_mb + k * stride_b_k + n * stride_b_n; acc += TO_ACC(A_TO_REF(a[off_a]) - ATTR_A0) * TO_ACC(B_TO_REF(b[off_b]) - ATTR_B0); }