Commit 5b376c1

Revert "[Runtime] Rework constexpr_function to support cache invalidation (#7…"
This reverts commit a977e39.
1 parent abb66fe commit 5b376c1

21 files changed, +243 -310 lines changed

python/test/unit/language/test_frontend.py

Lines changed: 6 additions & 6 deletions
@@ -308,7 +308,7 @@ def test_aggregate_with_constexpr():
     # CHECK: arith.addi %arg0, %cst : tensor<4xi32>


-@triton.constexpr_function
+@tl.constexpr_function
 def constexpr_function(x):
     return x + 1

@@ -345,12 +345,12 @@ def test_reassign_aggregate_with_constexpr():
     agg = agg.modify(tl.arange(4, 8))


-@triton.constexpr_function
+@tl.constexpr_function
 def make_shape(m, n):
     return (m, n)


-@triton.constexpr_function
+@tl.constexpr_function
 def add_shape_dims(m, n):
     return m + n

@@ -365,7 +365,7 @@ def test_constexpr_getitem():
     tl.arange(4, sum)


-@triton.constexpr_function
+@tl.constexpr_function
 def make_constexpr_closure(x):
     x = tl.constexpr(x)

@@ -386,7 +386,7 @@ def test_constexpr_closure():
     closure((128, 128))


-@triton.constexpr_function
+@tl.constexpr_function
 def make_constexpr_generator(f):
     f = tl.constexpr(f)

@@ -422,7 +422,7 @@ def test_constexpr_generator():
     generator(lhs)


-@triton.constexpr_function
+@tl.constexpr_function
 def Box(T):

     @tl.core._aggregate

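The hunks above switch these tests from @triton.constexpr_function back to @tl.constexpr_function: after the revert, the decorator is exposed on triton.language rather than on the top-level triton module. A minimal sketch of the resulting usage pattern, mirroring the helpers in the diff (the kernel name and output buffer are illustrative, not part of the change):

import triton
import triton.language as tl


@tl.constexpr_function
def add_shape_dims(m, n):
    # Evaluated at compile time; the result behaves like a tl.constexpr.
    return m + n


@triton.jit
def example_kernel(out):
    # Illustrative kernel: the helper's result is a compile-time constant,
    # so it can size tl.arange (out is assumed to be an int32 buffer).
    total: tl.constexpr = add_shape_dims(4, 4)
    offs = tl.arange(0, total)
    tl.store(out + offs, offs)
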
python/test/unit/language/test_tuple.py

Lines changed: 8 additions & 2 deletions
@@ -217,17 +217,23 @@ def m_to_the_n(X, shape: tl.constexpr, strides, m_n):

 def test_passing_tuple_to_make_tensor_descriptor(device, with_allocator):

+    from triton.language.core import builtin
+
+    @builtin
+    def is_constexpr(v, _semantic=None):
+        return isinstance(v, tl.constexpr)
+
     @triton.jit
     def m_to_the_n(X_base, shape, strides, m_n, BLOCK_DIM: tl.constexpr):
-        tl.static_assert(isinstance(strides[1].type, tl.constexpr_type))
+        tl.static_assert(is_constexpr(strides[1]))
         X = tl.make_tensor_descriptor(
             X_base,
             shape=shape,
             strides=strides,
             block_shape=[BLOCK_DIM, BLOCK_DIM],
         )
         # Make sure tl.make_tensor_descriptor didn't modify strides (i.e. didn't unwrap the constexpr)
-        tl.static_assert(isinstance(strides[1].type, tl.constexpr_type))
+        tl.static_assert(is_constexpr(strides[1]))
         data = X.load([0, 0])
         # Include a for loop to ensure strides[1] is lifted into a constexpr
         # (otherwise cloning the local scope will fail).

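The test now installs a small @builtin helper so the static assertion can check the constexpr wrapper type directly. Outside a @triton.jit kernel, tl.constexpr is an ordinary wrapper object, so the same predicate can be sanity-checked standalone (a minimal sketch; the values are illustrative):

import triton.language as tl


def is_constexpr(v):
    # Same predicate the test registers as a builtin, usable in plain Python
    # because tl.constexpr is just a wrapper class around a Python value.
    return isinstance(v, tl.constexpr)


assert is_constexpr(tl.constexpr(7))
assert not is_constexpr(7)
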
python/test/unit/runtime/test_cache.py

Lines changed: 0 additions & 36 deletions
@@ -127,42 +127,6 @@ def test_combine_fn_change():
     seen_keys.add(key)


-@triton.constexpr_function
-def constexpr_flag_fn():
-    return False
-
-
-@triton.jit
-def constexpr_fn_user(out):
-    a: tl.constexpr = constexpr_flag_fn()
-    tl.store(out, a)
-
-
-def test_constexpr_fn_change():
-    baseline = constexpr_fn_user.cache_key
-
-    orig_src = constexpr_flag_fn.src
-    new_src = orig_src.replace("False", "True")
-    constexpr_flag_fn._unsafe_update_src(new_src)
-    constexpr_fn_user.hash = None
-    updated = constexpr_fn_user.cache_key
-    assert baseline != updated
-
-    constexpr_flag_fn._unsafe_update_src(orig_src)
-    constexpr_fn_user.hash = None
-    assert constexpr_fn_user.cache_key == baseline
-
-
-@triton.constexpr_function
-def invalid_constexpr_fn():
-    return torch.cuda.get_device_capability()
-
-
-def test_invalid_constexpr_fn():
-    with pytest.raises(RuntimeError):
-        invalid_constexpr_fn.cache_key
-
-
 def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines):
     temp_file.write_text(('# extra line\n' * num_extra_lines) + code)
     spec = importlib.util.spec_from_file_location("module.name", str(temp_file))

python/triton/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,7 @@
     InterpreterError,
     MockTensor,
 )
-from .runtime.jit import constexpr_function, jit
+from .runtime.jit import jit
 from .runtime._async_compile import AsyncCompileMode, FutureKernel
 from .compiler import compile, CompilationError
 from .errors import TritonError
@@ -36,7 +36,6 @@
     "CompilationError",
     "compile",
     "Config",
-    "constexpr_function",
     "FutureKernel",
     "heuristics",
     "InterpreterError",

python/triton/compiler/code_generator.py

Lines changed: 5 additions & 5 deletions
@@ -14,8 +14,9 @@
 from .._C.libtriton import ir, gluon_ir
 from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
 from ..language.core import _unwrap_if_constexpr, base_value, base_type
+from ..runtime.jit import get_jit_fn_file_line, get_full_name
 # ideally we wouldn't need any runtime component
-from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, ConstexprFunction, JITFunction
+from ..runtime import JITFunction
 from .._utils import find_paths_if, get_iterable_path, set_iterable_path

 from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
@@ -51,7 +52,7 @@ def _is_triton_tensor(o: Any) -> bool:


 def _is_constexpr(o: Any) -> bool:
-    return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
+    return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction))


 def _is_non_scalar_tensor(o: Any) -> bool:
@@ -395,7 +396,7 @@ def global_lookup(name: str, absent):
            val is absent,
            name in self.builtin_namespace, #
            type(val) is ModuleType, #
-           isinstance(val, JITCallable), #
+           isinstance(val, JITFunction), #
            getattr(val, "__triton_builtin__", False), #
            getattr(val, "__triton_aggregate__", False), #
            getattr(val, "__module__", "").startswith("triton.language"), #
@@ -1322,8 +1323,7 @@ def call_Function(self, node, fn, args, kws):
         if isinstance(fn, JITFunction):
             _check_fn_args(node, fn, args)
             return self.call_JitFunction(fn, args, kws)
-        if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
-                fn, ConstexprFunction):
+        if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
             extra_kwargs = dict()
             sig = inspect.signature(fn)
             if '_semantic' in sig.parameters:

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 from . import nvidia
-from ._runtime import constexpr_function, jit
+from ._runtime import jit
 from triton.language.core import must_use_result

-__all__ = ["constexpr_function", "jit", "must_use_result", "nvidia"]
+__all__ = ["jit", "must_use_result", "nvidia"]

python/triton/experimental/gluon/_runtime.py

Lines changed: 1 addition & 3 deletions
@@ -1,14 +1,12 @@
 from __future__ import annotations
 from triton.compiler.compiler import ASTSource
 from triton.backends.compiler import Language
-from triton.runtime.jit import JITFunction, constexpr_function
+from triton.runtime.jit import JITFunction
 from typing import TypeVar, Optional, Callable, Iterable, Union
 from triton._C.libtriton import ir

 T = TypeVar("T")

-__all__ = ["constexpr_function", "jit"]
-

 class GluonASTSource(ASTSource):


python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 import triton.language.core as tl_core
 from triton.language.core import (
     constexpr,
+    constexpr_function,
     base_value,
     base_type,
     dtype,
@@ -78,6 +79,7 @@

 __all__ = [
     "constexpr",
+    "constexpr_function",
     "base_value",
     "base_type",
     "dtype",

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 0 additions & 3 deletions
@@ -1,7 +1,6 @@
 from dataclasses import dataclass
 from typing import List, Optional
 from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
-from triton.runtime.jit import constexpr_function

 __all__ = [
     "AutoLayout",
@@ -261,7 +260,6 @@ def type(self):
         return constexpr_type(self)


-@constexpr_function
 def _get_shape_per_cta(shape, cta_split_num):
     shape_per_cta = shape
     if cta_split_num is not None:
@@ -325,7 +323,6 @@ def _to_ir(self, builder):
         )

     @staticmethod
-    @constexpr_function
     def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None,
                         cta_order=None):
         """Returns an NVMMASharedLayout with default swizzling for a given shape.

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 6 additions & 11 deletions
@@ -2,9 +2,9 @@
 from typing import Optional, Tuple, List, TYPE_CHECKING

 from dataclasses import dataclass
-from triton.runtime.jit import constexpr_function
+import triton
 from triton.experimental.gluon.language import _core as ttgl
-from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr, constexpr_function
 from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta
 from triton.experimental.gluon.language._semantic import _check

@@ -62,11 +62,6 @@ def mangle(self) -> str:
         return f"TL{block_str}{unpacked_str}{cta_split_str}TL"


-@constexpr_function
-def _cdiv(x, div):
-    return (x + div - 1) // div
-
-
 @constexpr_function
 def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None):
     """Returns a BlockedLayout compatible with load/store on tensor memory with the 32x32b instruction variant.
@@ -82,19 +77,19 @@ def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_sp
     if M == 64:
         threads_per_warp = [16, 2]
         if num_blocks == 1:
-            size_per_thread = [1, _cdiv(N, num_warp_groups * 2)]
+            size_per_thread = [1, triton.cdiv(N, num_warp_groups * 2)]
             warps_per_cta = [4, num_warp_groups]
         else:
-            size_per_thread = [1, _cdiv(N, 2)]
+            size_per_thread = [1, triton.cdiv(N, 2)]
             warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
-            warps_per_cta.append(_cdiv(num_warp_groups, warps_per_cta[0] // 4))
+            warps_per_cta.append(triton.cdiv(num_warp_groups, warps_per_cta[0] // 4))
     else:
         if shape[0] > 128:
             size_per_thread = [1, N]
             threads_per_warp = [32, 1]
             warps_per_cta = [4 * num_warp_groups, 1]
         else:
-            size_per_thread = [1, _cdiv(N, num_warp_groups)]
+            size_per_thread = [1, triton.cdiv(N, num_warp_groups)]
             threads_per_warp = [32, 1]
             warps_per_cta = [4, num_warp_groups]
     return BlockedLayout(
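
The deleted _cdiv helper and triton.cdiv compute the same ceiling division, so the layout arithmetic above is unchanged by the swap; a quick standalone check (illustrative values):

import triton


def _cdiv(x, div):
    # The helper removed above: ceiling division.
    return (x + div - 1) // div


# triton.cdiv uses the same formula, so the computed register layouts match.
assert triton.cdiv(128, 3) == _cdiv(128, 3) == 43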
