From 4a12cf53ec116c06e5d74073b54a3bca6046cb17 Mon Sep 17 00:00:00 2001 From: Octavian Maghiar Date: Mon, 4 Dec 2023 11:13:35 +0000 Subject: [PATCH] [RISC-V] Improve RVV kernel generator LMUL usage The RVV kernel generation script uses the provided LMUL to increase the number of accumulator registers. Since the effect of the LMUL is to group together the vector registers into larger ones, it actually should be used as a multiplier in the calculation of vlenmax. At the moment, no matter what LMUL is provided, the generated kernels would only set the maximum number of vector elements equal to VLEN/SEW. Commit changes the use of LMUL to properly adjust vlenmax. Note that an increase in LMUL results in a decrease in the number of effective vector registers. --- kernel/riscv64/generate_kernel.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/kernel/riscv64/generate_kernel.py b/kernel/riscv64/generate_kernel.py index e2ce97971a..8be7c9f9cc 100755 --- a/kernel/riscv64/generate_kernel.py +++ b/kernel/riscv64/generate_kernel.py @@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ): dest.write("ai += {M}*2;") dest.write() - - accumulation_regs = a_regs * N * settings['LMUL_ACC'].value + # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results + accumulation_regs = a_regs * N dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k", a_regs=a_regs*2, accumulation_regs=accumulation_regs*2 ) pass_regs = (accumulation_regs + a_regs)*2 - tmp_regs = 32-pass_regs + tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs if tmp_regs < 2: raise RuntimeError("Complex kernel would use too many registers!") @@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ): M = settings['M'].value N = settings['N'].value - vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value ) + vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value / + settings['ELEN_PARAM'].value) a_regs = max(int(M/vlenmax), 1) - accumulation_regs = a_regs * N * settings['LMUL_ACC'].value + # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results + accumulation_regs = a_regs * N required_regs = accumulation_regs + a_regs if is_complex: required_regs = required_regs * 2 + 2 @@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ): '''.format(tail_policy=settings['tail_policy'].value)) - if required_regs > 32: - raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available".format( - required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else '') + if required_regs > (32 // settings['LMUL_ACC'].value): + raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format( + required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value )) TRMM = (settings['op'].value == 'trmm') @@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ): def generate_M_tails( dest, settings, M, N ): M_tail = int(M/2) M_tail_min = settings['M_tail_scalar_from'].value - vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value ) + vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value + / settings['ELEN_PARAM'].value ) TRMM = (settings['op'].value == 'trmm') is_complex = settings['complex'].value generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real @@ -667,4 +670,4 @@ def OUTPUT(*args, **kwargs): ERROR("unsupported kernel type {}".format(settings['op'])) if __name__ == "__main__": - main() \ No newline at end of file + main()