5
5
import inspect
6
6
import itertools
7
7
import types
8
- from typing import (
9
- Any ,
10
- Callable ,
11
- Dict ,
12
- List ,
13
- Optional ,
14
- Tuple ,
15
- TYPE_CHECKING ,
16
- TypeVar ,
17
- Union ,
18
- )
19
- from typing_extensions import Never
8
+ from typing import Any , Callable , Dict , List , Optional , TYPE_CHECKING , TypeVar , Union
20
9
21
10
import torch
22
11
47
36
if TYPE_CHECKING :
48
37
from torch ._dynamo .symbolic_convert import InstructionTranslator
49
38
from torch ._guards import Source
50
- from torch ._higher_order_ops .triton_kernel_wrap import TritonGridType
51
- from torch .utils ._triton import TritonKernelType
52
39
53
40
54
41
_F = TypeVar ("_F" , bound = Callable )
@@ -1041,18 +1028,18 @@ def as_python_constant(self):
1041
1028
1042
1029
1043
1030
class DynamoTritonHOPifier (TritonHOPifier ):
1044
- def raise_unsupported (self , msg : str ) -> Never :
1031
+ def raise_unsupported (self , msg ) :
1045
1032
raise Unsupported (msg )
1046
1033
1047
- def is_callable (self , maybe_callable : Any ) -> bool :
1034
+ def is_callable (self , maybe_callable ) :
1048
1035
return isinstance (
1049
1036
maybe_callable , (NestedUserFunctionVariable , UserFunctionVariable )
1050
1037
)
1051
1038
1052
- def get_value (self , val : Any ) -> Any :
1039
+ def get_value (self , val ) :
1053
1040
return val .value
1054
1041
1055
- def check_grid (self , grid ) -> Tuple [ torch . fx . proxy . Proxy , ...] :
1042
+ def check_grid (self , grid ):
1056
1043
from .lists import BaseListVariable
1057
1044
1058
1045
if isinstance (grid , BaseListVariable ):
@@ -1065,7 +1052,7 @@ def call_grid(self, grid, meta, tx):
1065
1052
grid = grid .call_function (tx , [meta ], {})
1066
1053
return grid
1067
1054
1068
- def call_HOP (self , variable , grids , combined_args_raw , tx ) -> ConstantVariable :
1055
+ def call_HOP (self , variable , grids , combined_args_raw , tx ):
1069
1056
from .constant import ConstantVariable
1070
1057
from .dicts import ConstDictVariable
1071
1058
@@ -1136,10 +1123,6 @@ def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
1136
1123
1137
1124
1138
1125
class TritonKernelVariable (VariableTracker ):
1139
- grid : "TritonGridType"
1140
- kernel : "TritonKernelType"
1141
- kernel_idx : Optional [int ]
1142
-
1143
1126
def __init__ (self , kernel , kernel_idx , grid , ** kwargs ) -> None :
1144
1127
super ().__init__ (** kwargs )
1145
1128
dynamo_triton_hopifier_singleton .init_variable (self , kernel , kernel_idx , grid )
0 commit comments