
Commit 257178c

ArinaJJH and zhzhcookie authored

[BACKEND] Add passes to metax (#47)

Co-authored-by: zhengyang <zhengyang@baai.ac.cn>
1 parent c03aac4 commit 257178c

File tree

3 files changed: +75, -14 lines

  third_party/metax/backend/compiler.py
  third_party/metax/backend/driver.c
  third_party/metax/backend/driver.py


third_party/metax/backend/compiler.py

Lines changed: 34 additions & 6 deletions
@@ -212,6 +212,7 @@ def make_ttgir(mod, metadata, opt, capability):
 
     if opt.pipeline == "cpasync":
         disable_prefetch = True
+    metax.passes.ttgpuir.add_tritonmetaxgpu_change_layout_for_int8_pass(pm, opt.num_stages, opt.pipeline)
     metax.passes.ttgpuir.add_accelerate_matmul(pm, opt.num_stages, disable_prefetch, store_coalesce, "c500")
     passes.ttgpuir.add_remove_layout_conversions(pm)
     if store_coalesce:
@@ -236,8 +237,11 @@ def make_ttgir(mod, metadata, opt, capability):
             metax.passes.ttgpuir.add_pipeline_async_tt(pm, opt.num_stages)
             metax.passes.ttgpuir.add_pipeline_async_base(pm, opt.num_stages, fullstage)
         elif mla and opt.num_stages == 2 and opt.pipeline == "cpasync":
-            metax.passes.ttgpuir.add_pipeline_async_multidot_mla(pm, opt.num_stages, fullstage,
-                                                                 opt.pipeline_load_num)
+            metax.passes.ttgpuir.add_pipeline_async_multidot_mla_mixed(pm, opt.num_stages, fullstage,
+                                                                       opt.pipeline_load_num, single_shm, True)
+        elif mla and opt.num_stages == 2 and opt.pipeline == "mixed":
+            metax.passes.ttgpuir.add_pipeline_async_multidot_mla_mixed(pm, opt.num_stages, fullstage,
+                                                                       opt.pipeline_load_num, single_shm, False)
         else:
             print("no avalilable pipeline for maca")
     else:
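Note on the hunk above: the dedicated cpasync MLA pass is replaced by a single add_pipeline_async_multidot_mla_mixed entry point that also serves the new "mixed" pipeline. A minimal sketch of the equivalent dispatch, written as a standalone function for readability; the wrapper name is illustrative, and mla, fullstage, and single_shm are assumed to be computed earlier in make_ttgir exactly as in the surrounding code:

def schedule_mla_pipeline(metax, pm, opt, mla, fullstage, single_shm):
    # Illustrative wrapper only; the commit inlines this in make_ttgir.
    # The trailing boolean appears to distinguish cpasync (True) from "mixed" (False).
    if mla and opt.num_stages == 2 and opt.pipeline in ("cpasync", "mixed"):
        metax.passes.ttgpuir.add_pipeline_async_multidot_mla_mixed(
            pm, opt.num_stages, fullstage, opt.pipeline_load_num,
            single_shm, opt.pipeline == "cpasync")
    else:
        print("no available pipeline for maca")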
@@ -252,7 +256,7 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.ttgpuir.add_reorder_instructions(pm)
     if os.getenv("TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP"):
         metax.passes.ttgpuir.add_tritonmetaxgpu_move_dot_operands_out_loop_pass(pm)
-    if os.getenv("TRITON_ENABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT"):
+    if not os.getenv("TRITON_DISABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT"):
         metax.passes.ttgpuir.add_tritonmetaxgpu_merge_equal_shared_layout_pass(pm)
     passes.common.add_cse(pm)
     passes.common.add_symbol_dce(pm)
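This hunk flips the merge-equal-shared-layout pass from opt-in to opt-out: it now runs unless TRITON_DISABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT is set. A hedged usage sketch; any non-empty value disables the pass, matching the os.getenv() truthiness check above:

import os

# Assumed usage: export before compiling a Triton kernel to skip the pass.
# Leaving the variable unset (the default) keeps the pass enabled.
os.environ["TRITON_DISABLE_MACA_MERGE_EQUAL_SHARED_LAYOUT"] = "1"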
@@ -322,14 +326,38 @@ def make_mcfatbin(src, metadata, opt, capability):
         if "roll" not in scenarios:
             compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " "
     elif opt.pipeline == "cpasync" and "mla" not in scenarios:
-        compile_options = " -mllvm -metaxgpu-sched-regpressure=true -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true \
-            -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true -mllvm -metaxgpu-shl-add-combine=false \
-            -mllvm -misched-postra=true -mllvm -enable-post-misched=true "
+        compile_options = " -mllvm -metaxgpu-sched-regpressure=true "
+        compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true \
+            -mllvm -metaxgpu-shl-add-combine=false -mllvm -misched-postra=true -mllvm -enable-post-misched=true "
 
         if os.getenv("TRITON_ENABLE_MACA_COMPILER_INT8_OPT"):
             compile_options += " -mllvm -metaxgpu-slp-vectorize-i8=true"
+
         if "unroll" in scenarios:
             compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " "
+    if "flashattn-fwd" in scenarios:
+        compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -metaxgpu-sched-select=metaxgpu-minreg -mllvm -map-use-pk-fma=1 "
+    elif "flashattn-bwd" in scenarios:
+        compile_options = " -mllvm -metaxgpu-sched-regpressure=true "
+        compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true "
+    if "mla" in scenarios:
+        # maybe will change the compile options in mla later
+        if opt.num_stages == 2:
+            if opt.pipeline == "cpasync":
+                compile_options = " -mllvm -metaxgpu-sched-regpressure=true "
+                compile_options += " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true \
+                    -mllvm -metaxgpu-shl-add-combine=false -mllvm -misched-postra=true -mllvm -enable-post-misched=true "
+
+                if "unroll" in scenarios:
+                    compile_options += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " "
+            elif opt.pipeline == "basic" or opt.pipeline == "mixed":
+                compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1 -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true \
+                    -mllvm -metaxgpu-disable-licm=true "
+
+            else:
+                assert False, "Please set pipeline for mla!"
+        else:
+            compile_options = " -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1 -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true "
     if opt.extra_options != "":
         compile_options = opt.extra_options
     return metax.translate_llvmir_to_mcfatbin(src, mxcc_arch, os.environ.get('MACA_PATH'), compile_options)
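The nested MLA branch in the make_mcfatbin hunk is easier to follow outside diff form. A condensed sketch, assuming opt and scenarios carry the same meaning as above; the helper name and the exact spacing of the option strings are illustrative, not part of the commit:

def _mla_compile_options(opt, scenarios):
    """Illustrative helper: the commit inlines this under `if "mla" in scenarios:`."""
    if opt.num_stages != 2:
        return (" -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1"
                " -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true ")
    if opt.pipeline == "cpasync":
        opts = (" -mllvm -metaxgpu-sched-regpressure=true"
                " -mllvm -metaxgpu-sinkload=false -mllvm -metaxgpu-vectorize-slp=true"
                " -mllvm -metaxgpu-igroup -mllvm -metaxgpu-aggressive-4g-addr-opt=true"
                " -mllvm -metaxgpu-shl-add-combine=false -mllvm -misched-postra=true"
                " -mllvm -enable-post-misched=true ")
        if "unroll" in scenarios:
            opts += " -mllvm -metaxgpu-mma-unroll-count=" + str(opt.num_stages) + " "
        return opts
    if opt.pipeline in ("basic", "mixed"):
        return (" -mllvm -metaxgpu-mma-sched=true -mllvm -map-use-pk-fma=1"
                " -mllvm -metaxgpu-split-regalloc=true -mllvm -metaxgpu-aggressive-fold=true"
                " -mllvm -metaxgpu-disable-licm=true ")
    # The commit uses `assert False` at this point.
    raise AssertionError("Please set pipeline for mla!")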

third_party/metax/backend/driver.c

Lines changed: 31 additions & 1 deletion
@@ -106,6 +106,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   MCcontext pctx = 0;
 
   Py_BEGIN_ALLOW_THREADS;
+  // TODO: MCcontext implement not found
   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxGetCurrent(&pctx));
   if (!pctx) {
     MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
@@ -121,7 +122,6 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
       mcFuncGetAttribute(&n_spills, MC_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
   n_spills /= 4;
-
   Py_END_ALLOW_THREADS;
 
   if (PyErr_Occurred()) {
@@ -141,6 +141,36 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
     return NULL;
   }
 
+  Py_BEGIN_ALLOW_THREADS;
+
+  // Ensure we have an active context.
+  // MCcontext ctx = NULL;
+  // TODO: CU_LIMIT_PRINTF_FIFO_SIZE implement not found
+  // MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxGetCurrent(&ctx));
+  // if (!ctx) {
+  //   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+  //       mcDevicePrimaryCtxRetain(&ctx, /*device=*/0));
+  //   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(mcCtxSetCurrent(ctx));
+  // }
+
+  // // We can't set the fifo size after running a kernel that calls printf.
+  // This
+  // // is true even if the set() call is a nop and the new size is the same as
+  // the
+  // // old size.
+  // //
+  // // This is unfriendly, so check if the old size matches the new size, and
+  // skip
+  // // the set() call if so.
+  // size_t oldSize = 0;
+  // MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+  //     mcCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE));
+  // if (oldSize != size) {
+  //   MACA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
+  //       mcCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size));
+  // }
+
+  Py_END_ALLOW_THREADS;
   return Py_None;
 }
 

third_party/metax/backend/driver.py

Lines changed: 10 additions & 7 deletions
@@ -143,6 +143,7 @@ def format_of(ty):
 #include <stdbool.h>
 #include <Python.h>
 #include <dlfcn.h>
+#include <stdlib.h> // MACA: for getenv
 
 static inline void gpuAssert(mcError_t code, const char *file, int line)
 {{
@@ -201,14 +202,16 @@ def format_of(ty):
     ptr_info.dev_ptr = (mcDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
     if(!ptr_info.dev_ptr)
       return ptr_info;
-    uint64_t dev_ptr;
-    int status = mcPointerGetAttribute(&dev_ptr, mcPointerAttributeDevicePointer, ptr_info.dev_ptr);
-    if (status == mcErrorInvalidValue) {{
-        PyErr_Format(PyExc_ValueError,
-                     "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
-        ptr_info.valid = false;
+    if (getenv("TRITON_DISABLE_DEVICE_POINTER_ATTR_CHECK") == NULL) {{
+      uint64_t dev_ptr;
+      int status = mcPointerGetAttribute(&dev_ptr, mcPointerAttributeDevicePointer, ptr_info.dev_ptr);
+      if (status == mcErrorInvalidValue) {{
+          PyErr_Format(PyExc_ValueError,
+                       "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
+          ptr_info.valid = false;
+      }}
+      ptr_info.dev_ptr = (mcDeviceptr_t)dev_ptr;
     }}
-    ptr_info.dev_ptr = (mcDeviceptr_t)dev_ptr;
     Py_DECREF(ret); // Thanks ChatGPT!
     return ptr_info;
 }}
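With this hunk, the generated launcher skips the mcPointerGetAttribute device-pointer check when TRITON_DISABLE_DEVICE_POINTER_ATTR_CHECK is set (read via getenv() in the generated C above). A hedged usage sketch from the Python side:

import os

# Assumed usage: set before JIT-compiling/launching a Triton kernel so the
# generated launcher's getenv() call sees it; leaving it unset (the default)
# keeps the "cpu tensor?" pointer validation.
os.environ["TRITON_DISABLE_DEVICE_POINTER_ATTR_CHECK"] = "1"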
