@@ -212,6 +212,7 @@ def make_ttgir(mod, metadata, opt, capability):
212
212
213
213
if opt .pipeline == "cpasync" :
214
214
disable_prefetch = True
215
+ metax .passes .ttgpuir .add_tritonmetaxgpu_change_layout_for_int8_pass (pm , opt .num_stages , opt .pipeline )
215
216
metax .passes .ttgpuir .add_accelerate_matmul (pm , opt .num_stages , disable_prefetch , store_coalesce , "c500" )
216
217
passes .ttgpuir .add_remove_layout_conversions (pm )
217
218
if store_coalesce :
@@ -236,8 +237,11 @@ def make_ttgir(mod, metadata, opt, capability):
236
237
metax .passes .ttgpuir .add_pipeline_async_tt (pm , opt .num_stages )
237
238
metax .passes .ttgpuir .add_pipeline_async_base (pm , opt .num_stages , fullstage )
238
239
elif mla and opt .num_stages == 2 and opt .pipeline == "cpasync" :
239
- metax .passes .ttgpuir .add_pipeline_async_multidot_mla (pm , opt .num_stages , fullstage ,
240
- opt .pipeline_load_num )
240
+ metax .passes .ttgpuir .add_pipeline_async_multidot_mla_mixed (pm , opt .num_stages , fullstage ,
241
+ opt .pipeline_load_num , single_shm , True )
242
+ elif mla and opt .num_stages == 2 and opt .pipeline == "mixed" :
243
+ metax .passes .ttgpuir .add_pipeline_async_multidot_mla_mixed (pm , opt .num_stages , fullstage ,
244
+ opt .pipeline_load_num , single_shm , False )
241
245
else :
242
246
print ("no avalilable pipeline for maca" )
243
247
else :
@@ -252,7 +256,7 @@ def make_ttgir(mod, metadata, opt, capability):
252
256
passes .ttgpuir .add_reorder_instructions (pm )
253
257
if os .getenv ("TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP" ):
254
258
metax .passes .ttgpuir .add_tritonmetaxgpu_move_dot_operands_out_loop_pass (pm )
255
- if os .getenv ("TRITON_ENABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT " ):
259
+ if not os .getenv ("TRITON_DISABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT " ):
256
260
metax .passes .ttgpuir .add_tritonmetaxgpu_merge_equal_shared_layout_pass (pm )
257
261
passes .common .add_cse (pm )
258
262
passes .common .add_symbol_dce (pm )
@@ -322,14 +326,38 @@ def make_mcfatbin(src, metadata, opt, capability):
322
326
if "roll" not in scenarios :
323
327
compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str (opt .num_stages ) + " "
324
328
elif opt .pipeline == "cpasync" and "mla" not in scenarios :
325
- compile_options = " -mllvm -metaxgpu-sched-regpressure=true -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true \
326
- -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true -mllvm -metaxgpu-shl-add-combine=false \
327
- -mllvm -misched-postra=true -mllvm -enable-post-misched=true "
329
+ compile_options = " -mllvm -metaxgpu-sched-regpressure=true "
330
+ compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true \
331
+ -mllvm -metaxgpu-shl-add-combine=false -mllvm - misched-postra=true -mllvm -enable-post-misched=true "
328
332
329
333
if os .getenv ("TRITON_ENABLE_MACA_COMPILER_INT8_OPT" ):
330
334
compile_options += " -mllvm -metaxgpu-slp-vectorize-i8=true"
335
+
331
336
if "unroll" in scenarios :
332
337
compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str (opt .num_stages ) + " "
338
+ if "flashattn-fwd" in scenarios :
339
+ compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -metaxgpu-sched-select=metaxgpu-minreg -mllvm -map-use-pk-fma=1 "
340
+ elif "flashattn-bwd" in scenarios :
341
+ compile_options = " -mllvm -metaxgpu-sched-regpressure=true "
342
+ compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true "
343
+ if "mla" in scenarios :
344
+ # maybe will change the compile options in mla later
345
+ if opt .num_stages == 2 :
346
+ if opt .pipeline == "cpasync" :
347
+ compile_options = " -mllvm -metaxgpu-sched-regpressure=true "
348
+ compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true \
349
+ -mllvm -metaxgpu-shl-add-combine=false -mllvm -misched-postra=true -mllvm -enable-post-misched=true "
350
+
351
+ if "unroll" in scenarios :
352
+ compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str (opt .num_stages ) + " "
353
+ elif opt .pipeline == "basic" or opt .pipeline == "mixed" :
354
+ compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1 -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true \
355
+ -mllvm -metaxgpu-disable-licm=true "
356
+
357
+ else :
358
+ assert False , "Please set pipeline for mla!"
359
+ else :
360
+ compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1 -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true "
333
361
if opt .extra_options != "" :
334
362
compile_options = opt .extra_options
335
363
return metax .translate_llvmir_to_mcfatbin (src , mxcc_arch , os .environ .get ('MACA_PATH' ), compile_options )
0 commit comments