[FRONTEND] enable construction of named tuples inside triton functions (
ptillet authored Jan 3, 2025
1 parent dc261bf commit 37817d7
Showing 2 changed files with 22 additions and 9 deletions.
22 changes: 13 additions & 9 deletions python/test/unit/language/test_tuple.py
@@ -114,19 +114,23 @@ class Tensor(NamedTuple):
 
 
 @triton.jit
-def _namedtuple_kernel(closure, X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
-    # load x
-    mask_x = (offs_m[:, None] < X.shape[0]) & (offs_n[None, :] < X.shape[1])
+    mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1])
+    return mask
+
+
+@triton.jit
+def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
     Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
-    x = tl.load(Xs, mask=mask_x, other=0)
-    # compute y
-    y = closure.fn(x, *closure.captured)
-    # store y
-    mask_y = (offs_m[:, None] < Y.shape[0]) & (offs_n[None, :] < Y.shape[1])
     Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
-    tl.store(Ys, y, mask=mask_y)
+    x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
+    y = closure.fn(x, *closure.captured)
+    tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))
 
 
 def test_namedtuple(device="cuda"):
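The updated test above exercises the new behavior: a NamedTuple class defined at module level (Tensor) is instantiated inside a @triton.jit function with keyword arguments, and its fields are then read back. Below is a minimal standalone sketch of the same idea, not part of this commit; it assumes a Triton build that includes this change, and the names (PtrPair, _add_kernel, add) are illustrative only.

from typing import NamedTuple

import torch
import triton
import triton.language as tl


class PtrPair(NamedTuple):
    x: object
    y: object


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Constructing the named tuple *inside* the jitted function is what this
    # commit enables; the fields are then accessed by name, as in the test.
    pair = PtrPair(x=x_ptr, y=y_ptr)
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    a = tl.load(pair.x + offs, mask=mask)
    b = tl.load(pair.y + offs, mask=mask)
    tl.store(out_ptr + offs, a + b, mask=mask)


def add(x, y):
    out = torch.empty_like(x)
    _add_kernel[(1, )](x, y, out, x.numel(), BLOCK=triton.next_power_of_2(x.numel()))
    return out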
9 changes: 9 additions & 0 deletions python/triton/compiler/code_generator.py
@@ -315,6 +315,9 @@ def _is_constexpr_global(self, name):
 
         return False
 
+    def _is_namedtuple(self, val):
+        return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
+
     def _define_name_lookup(self):
 
         def local_lookup(name: str, absent):
@@ -333,6 +336,7 @@ def global_lookup(name: str, absent):
                 getattr(val, "__triton_builtin__", False), #
                 getattr(val, "__module__", "").startswith("triton.language"), #
                 isinstance(val, language.dtype), #
+                self._is_namedtuple(val),
                 self._is_constexpr_global(name), #
                 # Allow accesses to globals while visiting an ast.arg
                 # because you should be able to do
@@ -535,6 +539,11 @@ def assignTarget(self, target, value):
     def visit_Assign(self, node):
         # construct values to assign
         def _sanitize_value(value):
+            if self._is_namedtuple(type(value)):
+                vals = [_sanitize_value(v) for v in value]
+                types = [v.type for v in vals]
+                fields = type(value)._fields
+                return language.tuple(vals, language.tuple_type(types, fields))
             if isinstance(value, language.tuple):
                 return language.tuple([_sanitize_value(v) for v in value.values])
             native_nontensor_types = (language.dtype, language.tuple)
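On the compiler side, the change has two parts: global_lookup now lets a NamedTuple class pass through (so the class name is visible inside a kernel), and _sanitize_value converts an instance of such a class into a language.tuple whose tuple_type carries the field names. The _is_namedtuple helper is ordinary Python reflection; the short self-contained sketch below reproduces the same check outside of Triton (the Point class is purely illustrative).

from typing import NamedTuple


def is_namedtuple_class(val):
    # Same reflection-based test as the new CodeGenerator._is_namedtuple:
    # a NamedTuple class is a type, a tuple subclass, and exposes _fields.
    return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")


class Point(NamedTuple):
    x: int
    y: int


assert is_namedtuple_class(Point)            # the class itself passes
assert not is_namedtuple_class(Point(1, 2))  # an instance is not a type
assert not is_namedtuple_class(tuple)        # plain tuple has no _fields
# In _sanitize_value the instance case is handled separately: the values are
# sanitized element-wise and repackaged with type(value)._fields preserved.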
