Skip to content
12 changes: 12 additions & 0 deletions config/igemm_bwd_generate.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Settings for a 'generate'-mode codegen run in the backward ('bwd') direction.
# NOTE(review): assumes the config parser skips full-line '#' comments - confirm.
[codegen]
arch = 'gfx908'
code_object = 'cov3'
# NOTE(review): 'generate' is presumably handled by igemm_config_gen_driver - confirm.
mode = 'generate'
direction = 'bwd'

# Each value below is a quoted, comma-separated list; presumably the set of
# candidate values the generator enumerates - TODO confirm against the driver.
[codegen_config]
nxb = '4,1'
nxe = '0,1'
gemm_k = '16,8,4'
# NOTE(review): by its name, the micro tiles used when gemm_k is 4 - confirm.
micro_tile_with_gemm_k_4 = '16x32,32x16'
precision = 'fp32'
1 change: 1 addition & 0 deletions igemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .codegen import *
from .algo import *
from .igemm_codegen_driver import *
from .igemm_config_gen_driver import *

if sys.hexversion < 0x30600f0:
print("must use python 3.6+. current is {}".format(sys.version))
Expand Down
58 changes: 33 additions & 25 deletions igemm/algo/xdlops_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,37 +259,45 @@ def serialize(self):
# mt_m,mt_n,wt_m,wt_n,wt_k,ws,r_m,r_n,s_m,s_n, inst_mfma
ctrl_xdlops_mapping_fp32 = [
# ctrl_xdlops_mapping_t( 256, 256, 32, 64, 4, 2, 2, 2, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 256, 128, 64, 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 128, 256, 32, 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 256, 128, 64, 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 128, 256, 32, 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 256, 64 , 64, 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), #add by jane

ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 64 , 256, 32, 64, 1, 4, 1, 1, 1, 2, v_mfma_f32_32x32x1f32), #add by jane
ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32),
#ctrl_xdlops_mapping_t( 32 , 256, 16 , 64, 1, 4, 1, 1, 1, 2, v_mfma_f32_16x16x1f32), # added by jane; disabled because it trips the coalescing group assert
ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),

#ctrl_xdlops_mapping_t( 256, 16 , 64, 16, 2, 1, 1, 2, 1, v_mfma_f32_16x16x1f32), # TODO: this will fail in coalescing
#ctrl_xdlops_mapping_t( 16 , 256, 16, 64, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), # TODO: this will fail in coalescing

ctrl_xdlops_mapping_t( 128, 128, 32, 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 128, 128, 32, 32, 2, 4, 2, 2, 1, 1, v_mfma_f32_32x32x2f32),
ctrl_xdlops_mapping_t( 128, 128, 32, 64, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 128, 128, 32, 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 128, 128, 32, 32, 2, 4, 2, 2, 1, 1, v_mfma_f32_32x32x2f32),
ctrl_xdlops_mapping_t( 128, 128, 32, 64, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 128, 64 , 64, 32 , 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), #add by jane
ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32),
ctrl_xdlops_mapping_t( 128, 32 , 32, 8, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 128, 32 , 64, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane, it's better
ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better
#ctrl_xdlops_mapping_t( 128, 16 , 64, 4 , 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32),
#ctrl_xdlops_mapping_t( 16 , 128, 4 , 64, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32),
ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 64 , 32 , 32, 32, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better
ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32),
ctrl_xdlops_mapping_t( 32 , 64 , 32 , 32, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane;
ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32),
#ctrl_xdlops_mapping_t( 256, 4 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm
#ctrl_xdlops_mapping_t( 4 , 256, 4 , 64, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm
ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32),
Expand Down
2 changes: 1 addition & 1 deletion igemm/igemm_codegen_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_kernel_per_inc_file_name(ker, origin_file_name):
self.mc.emitter = emitter_per_inc_dict[kpi_file_name]
if IGEMM_EMIT_KERNEL_METADATA_PER_INC_FILE:
kinfo_per_inc_dict[kpi_file_name].append(kernel.get_kernel_info())

self._emit(';----------------------------------------------------------')
self._emit('; starting of kernel {}'.format(kernel.name()))
self._emit(kernel.tunable.serialize())
Expand Down
Loading