Skip to content

Commit

Permalink
xe: ocl: gemm: remove conditionals in main loop
Browse files Browse the repository at this point in the history
  • Loading branch information
rjoursler committed Sep 23, 2024
1 parent d9f2f4c commit e852cdd
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/gpu/intel/ocl/gemm/ref_gemm.cl
Original file line number Diff line number Diff line change
@@ -34,8 +34,8 @@ void get_strides(int mask, long dim0, long dim1, long dim2, long *str0,
__kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b,
__global C_DATA_T *c, __global BIA_DATA_T *bias, long offset_a0,
long offset_b0, long offset_c0, long offset_bias0, int transa,
-int transb, long MB, long M, long N, long K, long stride_a,
-long stride_b, long stride_c, long lda, long ldb, long ldc,
+int transb, long MB, long M, long N, long K, long stride_a_mb,
+long stride_b_mb, long stride_c, long lda, long ldb, long ldc,
float eltwise_alpha, float eltwise_beta, float eltwise_scale,
int bias_mask, __global int *ao, __global int *bo, __global int *c0,
int c0_mask, __global float *scales, long scale_stride, float beta) {
@@ -62,10 +62,15 @@ __kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b,
c0_mask, MB, M, N, &c0_strides[0], &c0_strides[1], &c0_strides[2]);
#endif

+long stride_a_m = transa ? lda : 1;
+long stride_a_k = transa ? 1 : lda;
+long stride_b_k = transb ? ldb : 1;
+long stride_b_n = transb ? 1 : ldb;
+
ACC_DATA_T acc = 0;
for (long k = 0; k < K; ++k) {
-long off_a = mb * stride_a + (transa ? m * lda + k : k * lda + m);
-long off_b = mb * stride_b + (transb ? k * ldb + n : n * ldb + k);
+long off_a = mb * stride_a_mb + m * stride_a_m + k * stride_a_k;
+long off_b = mb * stride_b_mb + k * stride_b_k + n * stride_b_n;
acc += TO_ACC(A_TO_REF(a[off_a]) - ATTR_A0)
* TO_ACC(B_TO_REF(b[off_b]) - ATTR_B0);
}
Expand Down

0 comments on commit e852cdd

Please sign in to comment.