[PYTHON] Allow triton.code_gen.Binary to print Triton-IR asm. (triton…
daadaada authored and ptillet committed Jul 27, 2021
1 parent 1112e25 commit f668837
Showing 3 changed files with 14 additions and 7 deletions.
4 changes: 2 additions & 2 deletions lib/codegen/selection/generator.cc
@@ -817,9 +817,9 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
 // update accumulators
 unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
 unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
-for(unsigned K = 0; K < NK; K += 4)
 for(unsigned m = 0; m < num_m/2; m++)
-for(unsigned n = 0; n < num_n/2; n++)
+for(unsigned K = 0; K < NK; K += 4){
+for(unsigned n = 0; n < num_n/2; n++) {
 if(has.find({m, K}) == has.end()){
 Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
 int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
5 changes: 4 additions & 1 deletion python/src/triton.cc
@@ -7,6 +7,7 @@
 #include "triton/ir/enums.h"
 #include "triton/ir/function.h"
 #include "triton/ir/module.h"
+#include "triton/ir/print.h"
 #include <optional>
 #include <pybind11/buffer_info.h>
 #include <pybind11/functional.h>
@@ -78,7 +79,9 @@ void init_triton_codegen(py::module &&m) {
 drv::kernel *ker;
 size_t shared_mem;
 triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
-return std::make_tuple(mod, ker, shared_mem);
+std::stringstream ss;
+ir::print(ir, ss);
+return std::make_tuple(mod, ker, shared_mem, ss.str());
 },
 py::return_value_policy::take_ownership);
 }
12 changes: 8 additions & 4 deletions python/triton/code_gen.py
@@ -387,13 +387,17 @@ def generic_visit(self, node):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, shared_mem):
+    def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
+        # cache ir asm
+        self.ir_asm = ir_asm
         self.module = module
         self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
 
     def asm(self, mode):
+        if mode == 'ttir':
+            return self.ir_asm
         if mode == 'ptx':
             return self.module.ptx()
         if mode == 'llir':
@@ -495,8 +499,8 @@ def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
-        return Binary(mod, ker, num_warps, shared_mem)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
+        return Binary(mod, ker, num_warps, shared_mem, ir_asm)
 
     def __call__(self, *wargs, grid, num_warps=4, **meta):
         # device inference
@@ -576,7 +580,7 @@ def __call__(self, *args, **meta):
             config = self.cache[key]
         else:
             config = self.configs[0]
-        self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
 
 
 class JITFunction:
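For context, a minimal usage sketch of the new mode (not part of the commit): once JITFunction._compile returns a Binary, asm('ttir') yields the Triton-IR text cached at compile time, alongside the existing 'ptx' and 'llir' modes. The kernel name, tensor arguments, and device index below are hypothetical placeholders.

# Hypothetical usage sketch, assuming `add_kernel` is a @triton.jit function and
# `x, y, out` are CUDA tensors on device 0.
binary = add_kernel._compile(x, y, out,
                             device=torch.device('cuda:0'),
                             attributes={}, constants={}, num_warps=4)
print(binary.asm('ttir'))  # Triton-IR text cached in Binary.ir_asm by this commit
print(binary.asm('ptx'))   # PTX, unchanged behavior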
