Skip to content

Commit da71124

Browse files
authored
dynamic shape support (apache#22)
2 parents 6559665 + 758b670 commit da71124

18 files changed

+1128
-70
lines changed

.mypy.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ exclude = (?x)(
44
)
55
strict = True
66
[mypy-torch.*]
7+
follow_imports = skip
8+
[mypy-sympy.*]
79
follow_imports = skip

frontend/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
CONFIG = {
44
"backend": "inductor", # Union[str, Callable[..., Any]]
55
"debug": True,
6+
"dynshape": False,
67
}
78

89

frontend/fx_graph.py

Lines changed: 252 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
from typing import Any, Callable, Dict, Optional, Tuple, Union
2+
from functools import partial
3+
import copy
4+
import collections
25
import torch
36
import torch.fx
7+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
8+
from torch._guards import Source
49
import torch._inductor.compile_fx
510
import torch._dynamo.backends.torchxla
6-
from .utils import NO_LD_PRELOAD_CTX
11+
import torch.fx.immutable_collections as fx_immutable
12+
from torch._dispatch.python import enable_python_dispatcher
13+
from torch import SymInt, SymFloat, SymBool
14+
from torch.fx.experimental.symbolic_shapes import Symbol
15+
from sympy.printing.str import StrPrinter
16+
import sympy
17+
from .no_preload import NO_LD_PRELOAD_CTX
718
from . import config
19+
from .utils import ScalarType
20+
from .pycode_generator import GuardFnCodegen
21+
from .store_pos import StorePos, StoreNegate, StoreInAttr, StoreInIndex
22+
from . import variables as vs
823

924
BaseArgumentTypes = Union[
1025
str,
@@ -35,6 +50,48 @@ def backend_compile(gm: torch.fx.GraphModule,
3550
raise RuntimeError(f"Unknown backend: {backend}")
3651

3752

53+
def guard_check_shapeenv(inputs: list[torch.Tensor], fake_inputs: list[Any],
54+
shape_env: ShapeEnv) -> bool:
55+
symbol2value: dict[Symbol, Any] = {}
56+
for fake_input, input in zip(fake_inputs, inputs):
57+
if isinstance(fake_input, torch._subclasses.FakeTensor):
58+
assert isinstance(input, torch.Tensor)
59+
if len(input.shape) != len(fake_input.shape):
60+
return False
61+
for symbol, value in zip(fake_input.shape, input.shape):
62+
expr = symbol.node.expr
63+
if expr in symbol2value:
64+
if symbol2value[expr] != value:
65+
print("false due to shape", fake_input.shape,
66+
input.shape)
67+
print("symbol2value", symbol2value[expr])
68+
return False
69+
else:
70+
symbol2value[expr] = value
71+
else:
72+
raise NotImplementedError
73+
for guard in shape_env.guards:
74+
val = guard.expr.subs(symbol2value)
75+
if not (val is sympy.true):
76+
print("guard fail", guard, symbol2value)
77+
return False
78+
return True
79+
80+
81+
class ShapeGuardPrinter(StrPrinter): # type: ignore[misc]
82+
83+
def __init__(self, symbol_to_source: Dict[Symbol, list[StorePos]]):
84+
super().__init__()
85+
self.symbol_to_source = symbol_to_source
86+
87+
def _print_Symbol(self, expr: Symbol) -> str:
88+
assert isinstance(expr, Symbol), str(type(expr))
89+
assert expr in self.symbol_to_source, (
90+
f"{expr} (could be from {[s.name() for s in expr.sources]}) "
91+
f"not in {self.symbol_to_source}")
92+
return str(self.symbol_to_source[expr][0])
93+
94+
3895
class FxGraph:
3996
root: torch.nn.Module
4097
result_graph: torch.fx.Graph
@@ -47,9 +104,78 @@ def __init__(self, root: torch.nn.Module,
47104
self.root = root
48105
self.result_graph = torch.fx.Graph(root)
49106
self.mark_written_fn = mark_written_fn
50-
self.fake_mode = torch._subclasses.FakeTensorMode()
107+
self.dynamic_shape = config.get_config('dynshape')
108+
self.fake_mode = torch._subclasses.FakeTensorMode(
109+
shape_env=ShapeEnv() if self.dynamic_shape else None,
110+
# allow_non_fake_inputs=True
111+
)
51112
self.example_inputs = []
52113

114+
def infer_fake_value(self, node: torch.fx.Node) -> None:
115+
116+
def wrap_fake_exception(fn: Callable[[], Any]) -> Any:
117+
try:
118+
return fn()
119+
except torch._subclasses.UnsupportedFakeTensorException as e:
120+
msg = f"Unsupported: {e.reason} with fake tensor propagation."
121+
raise NotImplementedError(msg) from e
122+
123+
def as_fake_args_kwargs(
124+
args: Tuple[Any, ...],
125+
kwargs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
126+
127+
def as_fake(arg: Any) -> Any:
128+
if isinstance(arg, (tuple, list)):
129+
return fx_immutable.immutable_list(
130+
[as_fake(x) for x in arg])
131+
if isinstance(arg, slice):
132+
return slice(as_fake(arg.start), as_fake(arg.stop),
133+
as_fake(arg.step))
134+
if isinstance(arg, torch.fx.Node):
135+
return arg.meta["fake"]
136+
else:
137+
return arg
138+
139+
fake_args = tuple(as_fake(arg) for arg in args)
140+
fake_kwargs = {k: as_fake(v) for k, v in kwargs.items()}
141+
return fake_args, fake_kwargs
142+
143+
def fetch_attr(target: str) -> Any:
144+
target_atoms = target.split('.')
145+
attr_itr = self.root
146+
for i, atom in enumerate(target_atoms):
147+
if not hasattr(attr_itr, atom):
148+
raise RuntimeError(
149+
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
150+
)
151+
attr_itr = getattr(attr_itr, atom)
152+
return attr_itr
153+
154+
fake_args, fake_kwargs = as_fake_args_kwargs(node.args, node.kwargs)
155+
fake: Any = None
156+
op = node.op
157+
assert op not in ("placeholder", "output")
158+
if op == "get_attr":
159+
with self.fake_mode, enable_python_dispatcher():
160+
param = fetch_attr(node.target)
161+
fake = self.fake_mode.from_tensor(param, static_shapes=True)
162+
elif op == "call_function":
163+
with self.fake_mode, enable_python_dispatcher():
164+
fake = node.target(*fake_args, **fake_kwargs)
165+
elif op == "call_method":
166+
with self.fake_mode, enable_python_dispatcher():
167+
fake = getattr(fake_args[0], node.target)(*fake_args[1:],
168+
**fake_kwargs)
169+
elif op == "call_module":
170+
module = fetch_attr(node.target)
171+
with torch._subclasses.fake_tensor.FakeCopyMode(self.fake_mode):
172+
fake_module = wrap_fake_exception(lambda: copy.deepcopy(module))
173+
with self.fake_mode, enable_python_dispatcher():
174+
fake = fake_module(*fake_args, **fake_kwargs)
175+
else:
176+
raise RuntimeError(f"Unknown target: {node.target}")
177+
node.meta["fake"] = fake
178+
53179
def create_node(
54180
self,
55181
kind: str,
@@ -62,6 +188,9 @@ def create_node(
62188
self.mark_written_fn()
63189
result_node = self.result_graph.create_node(kind, target, args, kwargs,
64190
name, type_expr)
191+
if self.dynamic_shape:
192+
if kind not in ("placeholder", "output"):
193+
self.infer_fake_value(result_node)
65194
return result_node
66195

67196
def create_input(
@@ -73,11 +202,32 @@ def create_input(
73202
name: str,
74203
type_expr: Optional[Any] = None,
75204
) -> torch.fx.Node:
76-
fake_tensor = self.fake_mode.from_tensor(value, static_shapes=True)
205+
fake_tensor = self.fake_mode.from_tensor(
206+
value, static_shapes=not self.dynamic_shape)
77207
self.mark_written_fn()
78208
self.example_inputs.append((fake_tensor, name))
79-
return self.create_node("placeholder", target, args, kwargs, name,
209+
node = self.create_node("placeholder", target, args, kwargs, name,
80210
type_expr)
211+
node.meta["fake"] = fake_tensor
212+
return node
213+
214+
def create_sym_input(
215+
self,
216+
value: ScalarType,
217+
target: torch.fx.node.Target,
218+
args: Tuple[Any, ...],
219+
kwargs: Dict[str, Any],
220+
name: str,
221+
type_expr: Optional[Any] = None,
222+
) -> torch.fx.Node:
223+
symbol = self.fake_mode.shape_env.create_symbol(value, Source())
224+
fake = self.fake_mode.shape_env.create_symintnode(symbol, hint=value)
225+
self.mark_written_fn()
226+
self.example_inputs.append((fake, name))
227+
node = self.create_node("placeholder", target, args, kwargs, name,
228+
type_expr)
229+
node.meta["fake"] = fake
230+
return node
81231

82232
def set_output_nodes(self, output_nodes: list[torch.fx.Node]) -> None:
83233
for node in self.result_graph.nodes:
@@ -90,15 +240,110 @@ def compile(
90240
model = torch.fx.GraphModule(self.root, self.result_graph)
91241
model.recompile()
92242
with NO_LD_PRELOAD_CTX():
93-
compiled_fn = backend_compile(
94-
model, [x[0].contiguous() for x in self.example_inputs])
243+
compiled_fn = backend_compile(model, [
244+
x[0].contiguous() if isinstance(x[0], torch.Tensor) else x[0]
245+
for x in self.example_inputs
246+
])
95247
assert callable(compiled_fn)
248+
if self.fake_mode.shape_env is not None:
249+
print("shape_env guards", self.fake_mode.shape_env.format_guards())
96250
# TODO: add backend compiler
97251
return compiled_fn
98252

99253
def get_inputs(self) -> list[torch.fx.Node]:
100254
return [x for x in self.result_graph.nodes if x.op == "placeholder"]
101255

256+
def make_shape_env_guard(self, codegen: GuardFnCodegen) -> None:
257+
fake_inputs: list[torch.FakeTensor] = []
258+
poses: list[StorePos] = []
259+
for node in self.result_graph.nodes:
260+
if node.op == "placeholder":
261+
fake = node.meta["fake"]
262+
fake_inputs.append(fake)
263+
var = node.meta["var"]
264+
assert isinstance(var, (vs.TensorVar, vs.ScalarVar))
265+
pos = var.extract_code_at_start[0]
266+
poses.append(pos)
267+
self.produce_guards(fake_inputs, poses, codegen)
268+
269+
# modified from torch produce_guards
270+
def produce_guards(self, placeholders: list[Any], sources: list[StorePos],
271+
codegen: GuardFnCodegen) -> None:
272+
import math
273+
import operator
274+
SYMPY_INTERP = {
275+
'Eq': operator.eq,
276+
'Ne': operator.ne,
277+
'Gt': operator.gt,
278+
'Lt': operator.lt,
279+
'Le': operator.le,
280+
'Ge': operator.ge,
281+
'Min': min,
282+
'Max': max,
283+
'Mod': operator.mod,
284+
'FloorDiv': operator.floordiv,
285+
'TrueDiv': operator.truediv,
286+
'floor': math.floor,
287+
'ceiling': math.ceil,
288+
}
289+
for k, v in SYMPY_INTERP.items():
290+
codegen.add_obj(v, k, force=True)
291+
input_guards = []
292+
symbol_to_source = collections.defaultdict(list)
293+
294+
def track_symint(source: StorePos, val: Any) -> None:
295+
if isinstance(val, SymInt):
296+
s = val.node.expr
297+
298+
if isinstance(s, sympy.Symbol):
299+
symbol_to_source[s].append(source)
300+
elif isinstance(-s, sympy.Symbol):
301+
symbol_to_source[-s].append(StoreNegate(source))
302+
303+
input_guards.append((source, s))
304+
else:
305+
input_guards.append((source, sympy.Integer(val)))
306+
307+
for t, source in zip(placeholders, sources):
308+
assert isinstance(source, StorePos)
309+
if t is None:
310+
continue
311+
if isinstance(t, SymInt):
312+
track_symint(source, t)
313+
continue
314+
assert isinstance(t, torch.Tensor)
315+
for i, s in enumerate(t.size()):
316+
track_symint(
317+
StoreInIndex(StoreInAttr(source, 0, 'size()'), 0, i), s)
318+
319+
for source, expr in input_guards:
320+
# Small optimization
321+
if (isinstance(expr, Symbol) and expr in symbol_to_source and
322+
source == symbol_to_source[expr][0]):
323+
continue
324+
sexpr = ShapeGuardPrinter(symbol_to_source).doprint(expr)
325+
codegen.add_check(f"{source} == {sexpr}")
326+
327+
for g, tb in self.fake_mode.shape_env.guards:
328+
print("guard", g)
329+
if self.fake_mode.shape_env._maybe_evaluate_static(g) is not None:
330+
print("maybe static")
331+
continue
332+
print("before simplify", g)
333+
g = self.fake_mode.shape_env.simplify(g)
334+
print("after simplify", g)
335+
try:
336+
codegen.add_check(
337+
ShapeGuardPrinter(symbol_to_source).doprint(g))
338+
except Exception:
339+
print(f"Failing guard allocated at: \n{tb}")
340+
raise
341+
342+
for sources in symbol_to_source.values():
343+
assert sources
344+
codegen.add_check(f"{sources[0]} != 0")
345+
codegen.add_check(f"{sources[0]} != 1")
346+
102347

103348
frame_root: dict[int, torch.nn.Module] = {}
104349

@@ -127,4 +372,4 @@ def is_leaf_module(m: torch.nn.Module) -> bool:
127372

128373
def reset() -> None:
129374
global frame_root
130-
frame_root = {}
375+
frame_root = {}

0 commit comments

Comments
 (0)