 from typing import Any, Callable, Dict, Optional, Tuple, Union
+from functools import partial
+import copy
+import collections
 import torch
 import torch.fx
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch._guards import Source
 import torch._inductor.compile_fx
 import torch._dynamo.backends.torchxla
-from .utils import NO_LD_PRELOAD_CTX
+import torch.fx.immutable_collections as fx_immutable
+from torch._dispatch.python import enable_python_dispatcher
+from torch import SymInt, SymFloat, SymBool
+from torch.fx.experimental.symbolic_shapes import Symbol
+from sympy.printing.str import StrPrinter
+import sympy
+from .no_preload import NO_LD_PRELOAD_CTX
 from . import config
+from .utils import ScalarType
+from .pycode_generator import GuardFnCodegen
+from .store_pos import StorePos, StoreNegate, StoreInAttr, StoreInIndex
+from . import variables as vs
 
 BaseArgumentTypes = Union[
     str,
@@ -35,6 +50,48 @@ def backend_compile(gm: torch.fx.GraphModule,
         raise RuntimeError(f"Unknown backend: {backend}")
 
 
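+# Runtime check used with dynamic shapes: bind each symbolic dimension of the
+# traced fake inputs to the concrete sizes of the real inputs, then verify
+# that every guard recorded in the ShapeEnv holds under that binding.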
+def guard_check_shapeenv(inputs: list[torch.Tensor], fake_inputs: list[Any],
+                         shape_env: ShapeEnv) -> bool:
+    symbol2value: dict[Symbol, Any] = {}
+    for fake_input, input in zip(fake_inputs, inputs):
+        if isinstance(fake_input, torch._subclasses.FakeTensor):
+            assert isinstance(input, torch.Tensor)
+            if len(input.shape) != len(fake_input.shape):
+                return False
+            for symbol, value in zip(fake_input.shape, input.shape):
+                expr = symbol.node.expr
+                if expr in symbol2value:
+                    if symbol2value[expr] != value:
+                        print("false due to shape", fake_input.shape,
+                              input.shape)
+                        print("symbol2value", symbol2value[expr])
+                        return False
+                else:
+                    symbol2value[expr] = value
+        else:
+            raise NotImplementedError
+    for guard in shape_env.guards:
+        val = guard.expr.subs(symbol2value)
+        if val is not sympy.true:
+            print("guard fail", guard, symbol2value)
+            return False
+    return True
+
+
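+# Prints a sympy guard expression with each free symbol replaced by the
+# StorePos that produces it at runtime, so the printed expression can be
+# emitted directly into generated guard-check source (a symbol prints as the
+# string form of its first recorded StorePos).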
+class ShapeGuardPrinter(StrPrinter):  # type: ignore[misc]
+
+    def __init__(self, symbol_to_source: Dict[Symbol, list[StorePos]]):
+        super().__init__()
+        self.symbol_to_source = symbol_to_source
+
+    def _print_Symbol(self, expr: Symbol) -> str:
+        assert isinstance(expr, Symbol), str(type(expr))
+        assert expr in self.symbol_to_source, (
+            f"{expr} (could be from {[s.name() for s in expr.sources]}) "
+            f"not in {self.symbol_to_source}")
+        return str(self.symbol_to_source[expr][0])
+
+
 class FxGraph:
     root: torch.nn.Module
     result_graph: torch.fx.Graph
@@ -47,9 +104,78 @@ def __init__(self, root: torch.nn.Module,
         self.root = root
         self.result_graph = torch.fx.Graph(root)
         self.mark_written_fn = mark_written_fn
-        self.fake_mode = torch._subclasses.FakeTensorMode()
+        self.dynamic_shape = config.get_config('dynshape')
+        self.fake_mode = torch._subclasses.FakeTensorMode(
+            shape_env=ShapeEnv() if self.dynamic_shape else None,
+            # allow_non_fake_inputs=True
+        )
         self.example_inputs = []
 
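+    # Infer the fake value of a freshly created FX node by re-running its op
+    # under FakeTensorMode on the "fake" metadata of its inputs; the result is
+    # stored in node.meta["fake"] for later shape reasoning.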
+    def infer_fake_value(self, node: torch.fx.Node) -> None:
+
+        def wrap_fake_exception(fn: Callable[[], Any]) -> Any:
+            try:
+                return fn()
+            except torch._subclasses.UnsupportedFakeTensorException as e:
+                msg = f"Unsupported: {e.reason} with fake tensor propagation."
+                raise NotImplementedError(msg) from e
+
+        def as_fake_args_kwargs(
+                args: Tuple[Any, ...],
+                kwargs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
+
+            def as_fake(arg: Any) -> Any:
+                if isinstance(arg, (tuple, list)):
+                    return fx_immutable.immutable_list(
+                        [as_fake(x) for x in arg])
+                if isinstance(arg, slice):
+                    return slice(as_fake(arg.start), as_fake(arg.stop),
+                                 as_fake(arg.step))
+                if isinstance(arg, torch.fx.Node):
+                    return arg.meta["fake"]
+                return arg
+
+            fake_args = tuple(as_fake(arg) for arg in args)
+            fake_kwargs = {k: as_fake(v) for k, v in kwargs.items()}
+            return fake_args, fake_kwargs
+
+        def fetch_attr(target: str) -> Any:
+            target_atoms = target.split('.')
+            attr_itr = self.root
+            for i, atom in enumerate(target_atoms):
+                if not hasattr(attr_itr, atom):
+                    raise RuntimeError(
+                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+                    )
+                attr_itr = getattr(attr_itr, atom)
+            return attr_itr
+
+        fake_args, fake_kwargs = as_fake_args_kwargs(node.args, node.kwargs)
+        fake: Any = None
+        op = node.op
+        assert op not in ("placeholder", "output")
+        if op == "get_attr":
+            with self.fake_mode, enable_python_dispatcher():
+                param = fetch_attr(node.target)
+                fake = self.fake_mode.from_tensor(param, static_shapes=True)
+        elif op == "call_function":
+            with self.fake_mode, enable_python_dispatcher():
+                fake = node.target(*fake_args, **fake_kwargs)
+        elif op == "call_method":
+            with self.fake_mode, enable_python_dispatcher():
+                fake = getattr(fake_args[0], node.target)(*fake_args[1:],
+                                                          **fake_kwargs)
+        elif op == "call_module":
+            module = fetch_attr(node.target)
+            with torch._subclasses.fake_tensor.FakeCopyMode(self.fake_mode):
+                fake_module = wrap_fake_exception(lambda: copy.deepcopy(module))
+            with self.fake_mode, enable_python_dispatcher():
+                fake = fake_module(*fake_args, **fake_kwargs)
+        else:
+            raise RuntimeError(f"Unknown op: {op}")
+        node.meta["fake"] = fake
+
     def create_node(
         self,
         kind: str,
@@ -62,6 +188,9 @@ def create_node(
         self.mark_written_fn()
         result_node = self.result_graph.create_node(kind, target, args, kwargs,
                                                     name, type_expr)
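+        # With dynamic shapes enabled, eagerly propagate fake values through
+        # every computational node as the graph is built.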
+        if self.dynamic_shape:
+            if kind not in ("placeholder", "output"):
+                self.infer_fake_value(result_node)
         return result_node
 
     def create_input(
@@ -73,11 +202,32 @@ def create_input(
         name: str,
         type_expr: Optional[Any] = None,
     ) -> torch.fx.Node:
-        fake_tensor = self.fake_mode.from_tensor(value, static_shapes=True)
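+        # With dynamic shapes enabled, trace inputs with symbolic sizes;
+        # otherwise keep the previous fully static behavior.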
+        fake_tensor = self.fake_mode.from_tensor(
+            value, static_shapes=not self.dynamic_shape)
         self.mark_written_fn()
         self.example_inputs.append((fake_tensor, name))
-        return self.create_node("placeholder", target, args, kwargs, name,
+        node = self.create_node("placeholder", target, args, kwargs, name,
                                 type_expr)
+        node.meta["fake"] = fake_tensor
+        return node
+
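+    # Like create_input, but for a Python scalar: allocate a fresh symbol in
+    # the ShapeEnv (hinted with the observed value) and use the resulting
+    # SymInt as the placeholder's fake value.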
+    def create_sym_input(
+        self,
+        value: ScalarType,
+        target: torch.fx.node.Target,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        name: str,
+        type_expr: Optional[Any] = None,
+    ) -> torch.fx.Node:
+        symbol = self.fake_mode.shape_env.create_symbol(value, Source())
+        fake = self.fake_mode.shape_env.create_symintnode(symbol, hint=value)
+        self.mark_written_fn()
+        self.example_inputs.append((fake, name))
+        node = self.create_node("placeholder", target, args, kwargs, name,
+                                type_expr)
+        node.meta["fake"] = fake
+        return node
 
     def set_output_nodes(self, output_nodes: list[torch.fx.Node]) -> None:
         for node in self.result_graph.nodes:
@@ -90,15 +240,110 @@ def compile(
         model = torch.fx.GraphModule(self.root, self.result_graph)
         model.recompile()
         with NO_LD_PRELOAD_CTX():
-            compiled_fn = backend_compile(
-                model, [x[0].contiguous() for x in self.example_inputs])
+            compiled_fn = backend_compile(model, [
+                x[0].contiguous() if isinstance(x[0], torch.Tensor) else x[0]
+                for x in self.example_inputs
+            ])
             assert callable(compiled_fn)
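+        # Debug aid: dump the symbolic guards collected during tracing.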
+        if self.fake_mode.shape_env is not None:
+            print("shape_env guards", self.fake_mode.shape_env.format_guards())
         # TODO: add backend compiler
         return compiled_fn
 
     def get_inputs(self) -> list[torch.fx.Node]:
         return [x for x in self.result_graph.nodes if x.op == "placeholder"]
 
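+    # Collect the fake value and source-level store position of every
+    # placeholder, then lower the accumulated ShapeEnv state to guard checks.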
+    def make_shape_env_guard(self, codegen: GuardFnCodegen) -> None:
+        fake_inputs: list[torch._subclasses.FakeTensor] = []
+        poses: list[StorePos] = []
+        for node in self.result_graph.nodes:
+            if node.op == "placeholder":
+                fake = node.meta["fake"]
+                fake_inputs.append(fake)
+                var = node.meta["var"]
+                assert isinstance(var, (vs.TensorVar, vs.ScalarVar))
+                pos = var.extract_code_at_start[0]
+                poses.append(pos)
+        self.produce_guards(fake_inputs, poses, codegen)
+
+    # Modified from torch's ShapeEnv.produce_guards.
+    def produce_guards(self, placeholders: list[Any], sources: list[StorePos],
+                       codegen: GuardFnCodegen) -> None:
+        import math
+        import operator
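+        # Map the sympy function names that can appear in printed guard
+        # expressions to Python callables, so the generated guard source can
+        # be evaluated without sympy at check time.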
+        SYMPY_INTERP = {
+            'Eq': operator.eq,
+            'Ne': operator.ne,
+            'Gt': operator.gt,
+            'Lt': operator.lt,
+            'Le': operator.le,
+            'Ge': operator.ge,
+            'Min': min,
+            'Max': max,
+            'Mod': operator.mod,
+            'FloorDiv': operator.floordiv,
+            'TrueDiv': operator.truediv,
+            'floor': math.floor,
+            'ceiling': math.ceil,
+        }
+        for k, v in SYMPY_INTERP.items():
+            codegen.add_obj(v, k, force=True)
+        input_guards = []
+        symbol_to_source = collections.defaultdict(list)
+
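+        # Record which StorePos produces each symbol (first occurrence is the
+        # canonical source) and collect one (source, expr) pair per value so
+        # equality guards can be emitted below.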
+        def track_symint(source: StorePos, val: Any) -> None:
+            if isinstance(val, SymInt):
+                s = val.node.expr
+
+                if isinstance(s, sympy.Symbol):
+                    symbol_to_source[s].append(source)
+                elif isinstance(-s, sympy.Symbol):
+                    symbol_to_source[-s].append(StoreNegate(source))
+
+                input_guards.append((source, s))
+            else:
+                input_guards.append((source, sympy.Integer(val)))
+
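+        # SymInt inputs are tracked directly; tensor inputs contribute one
+        # tracked value per dimension of their size.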
+        for t, source in zip(placeholders, sources):
+            assert isinstance(source, StorePos)
+            if t is None:
+                continue
+            if isinstance(t, SymInt):
+                track_symint(source, t)
+                continue
+            assert isinstance(t, torch.Tensor)
+            for i, s in enumerate(t.size()):
+                track_symint(
+                    StoreInIndex(StoreInAttr(source, 0, 'size()'), 0, i), s)
+
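+        # Emit "source == expr" checks, skipping the trivial case where expr
+        # is a symbol whose canonical source is this source itself.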
+        for source, expr in input_guards:
+            # Small optimization
+            if (isinstance(expr, Symbol) and expr in symbol_to_source and
+                    source == symbol_to_source[expr][0]):
+                continue
+            sexpr = ShapeGuardPrinter(symbol_to_source).doprint(expr)
+            codegen.add_check(f"{source} == {sexpr}")
+
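+        # Re-express each guard recorded by the ShapeEnv in terms of store
+        # positions; guards that are statically decidable are skipped.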
+        for g, tb in self.fake_mode.shape_env.guards:
+            print("guard", g)
+            if self.fake_mode.shape_env._maybe_evaluate_static(g) is not None:
+                print("maybe static")
+                continue
+            print("before simplify", g)
+            g = self.fake_mode.shape_env.simplify(g)
+            print("after simplify", g)
+            try:
+                codegen.add_check(
+                    ShapeGuardPrinter(symbol_to_source).doprint(g))
+            except Exception:
+                print(f"Failing guard allocated at:\n{tb}")
+                raise
+
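+        # The ShapeEnv specializes sizes 0 and 1 instead of making them
+        # symbolic, so every dynamic dimension must additionally be checked
+        # against those values.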
+        for sources in symbol_to_source.values():
+            assert sources
+            codegen.add_check(f"{sources[0]} != 0")
+            codegen.add_check(f"{sources[0]} != 1")
+
 
 frame_root: dict[int, torch.nn.Module] = {}
 
@@ -127,4 +372,4 @@ def is_leaf_module(m: torch.nn.Module) -> bool:
 
 def reset() -> None:
     global frame_root
-    frame_root = {}
\ No newline at end of file
+    frame_root = {}