Skip to content

Commit e8b1409

Browse files
Revert "[user triton] typing triton_kernel_wrap.py (pytorch#138230)"
This reverts commit 2f61b69. Reverted pytorch#138230 on behalf of https://github.com/wdvr due to Reverting this, as it started failing tests on main ([comment](pytorch#138230 (comment)))
1 parent 4632594 commit e8b1409

File tree

4 files changed

+136
-278
lines changed

4 files changed

+136
-278
lines changed

torch/_dynamo/variables/builder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,18 @@ def _id_dispatch(
527527

528528
def _wrap(self, value):
529529
# import here to avoid circular dependencies
530-
from torch.utils._triton import Autotuner, has_triton_tma, JITFunction
530+
from torch.utils._triton import has_triton, has_triton_tma
531+
532+
if has_triton():
533+
from triton.runtime.autotuner import Autotuner
534+
from triton.runtime.jit import JITFunction
535+
else:
536+
537+
class JITFunction:
538+
pass
539+
540+
class Autotuner:
541+
pass
531542

532543
if has_triton_tma():
533544
from triton.tools.experimental_descriptor import (

torch/_dynamo/variables/functions.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,7 @@
55
import inspect
66
import itertools
77
import types
8-
from typing import (
9-
Any,
10-
Callable,
11-
Dict,
12-
List,
13-
Optional,
14-
Tuple,
15-
TYPE_CHECKING,
16-
TypeVar,
17-
Union,
18-
)
19-
from typing_extensions import Never
8+
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar, Union
209

2110
import torch
2211

@@ -47,8 +36,6 @@
4736
if TYPE_CHECKING:
4837
from torch._dynamo.symbolic_convert import InstructionTranslator
4938
from torch._guards import Source
50-
from torch._higher_order_ops.triton_kernel_wrap import TritonGridType
51-
from torch.utils._triton import TritonKernelType
5239

5340

5441
_F = TypeVar("_F", bound=Callable)
@@ -1041,18 +1028,18 @@ def as_python_constant(self):
10411028

10421029

10431030
class DynamoTritonHOPifier(TritonHOPifier):
1044-
def raise_unsupported(self, msg: str) -> Never:
1031+
def raise_unsupported(self, msg):
10451032
raise Unsupported(msg)
10461033

1047-
def is_callable(self, maybe_callable: Any) -> bool:
1034+
def is_callable(self, maybe_callable):
10481035
return isinstance(
10491036
maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable)
10501037
)
10511038

1052-
def get_value(self, val: Any) -> Any:
1039+
def get_value(self, val):
10531040
return val.value
10541041

1055-
def check_grid(self, grid) -> Tuple[torch.fx.proxy.Proxy, ...]:
1042+
def check_grid(self, grid):
10561043
from .lists import BaseListVariable
10571044

10581045
if isinstance(grid, BaseListVariable):
@@ -1065,7 +1052,7 @@ def call_grid(self, grid, meta, tx):
10651052
grid = grid.call_function(tx, [meta], {})
10661053
return grid
10671054

1068-
def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
1055+
def call_HOP(self, variable, grids, combined_args_raw, tx):
10691056
from .constant import ConstantVariable
10701057
from .dicts import ConstDictVariable
10711058

@@ -1136,10 +1123,6 @@ def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
11361123

11371124

11381125
class TritonKernelVariable(VariableTracker):
1139-
grid: "TritonGridType"
1140-
kernel: "TritonKernelType"
1141-
kernel_idx: Optional[int]
1142-
11431126
def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
11441127
super().__init__(**kwargs)
11451128
dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)

0 commit comments

Comments
 (0)