|
14 | 14 | from .._C.libtriton import ir, gluon_ir |
15 | 15 | from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple |
16 | 16 | from ..language.core import _unwrap_if_constexpr, base_value, base_type |
| 17 | +from ..runtime.jit import get_jit_fn_file_line, get_full_name |
17 | 18 | # ideally we wouldn't need any runtime component |
18 | | -from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, ConstexprFunction, JITFunction |
| 19 | +from ..runtime import JITFunction |
19 | 20 | from .._utils import find_paths_if, get_iterable_path, set_iterable_path |
20 | 21 |
|
21 | 22 | from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) |
@@ -51,7 +52,7 @@ def _is_triton_tensor(o: Any) -> bool: |
51 | 52 |
|
52 | 53 |
|
53 | 54 | def _is_constexpr(o: Any) -> bool: |
54 | | - return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable)) |
| 55 | + return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction)) |
55 | 56 |
|
56 | 57 |
|
57 | 58 | def _is_non_scalar_tensor(o: Any) -> bool: |
@@ -395,7 +396,7 @@ def global_lookup(name: str, absent): |
395 | 396 | val is absent, |
396 | 397 | name in self.builtin_namespace, # |
397 | 398 | type(val) is ModuleType, # |
398 | | - isinstance(val, JITCallable), # |
| 399 | + isinstance(val, JITFunction), # |
399 | 400 | getattr(val, "__triton_builtin__", False), # |
400 | 401 | getattr(val, "__triton_aggregate__", False), # |
401 | 402 | getattr(val, "__module__", "").startswith("triton.language"), # |
@@ -1322,8 +1323,7 @@ def call_Function(self, node, fn, args, kws): |
1322 | 1323 | if isinstance(fn, JITFunction): |
1323 | 1324 | _check_fn_args(node, fn, args) |
1324 | 1325 | return self.call_JitFunction(fn, args, kws) |
1325 | | - if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance( |
1326 | | - fn, ConstexprFunction): |
| 1326 | + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): |
1327 | 1327 | extra_kwargs = dict() |
1328 | 1328 | sig = inspect.signature(fn) |
1329 | 1329 | if '_semantic' in sig.parameters: |
|
0 commit comments