Skip to content

Commit a72d6f8

Browse files
authored
Parity 1 (apache#35)
2 parents 62fdc54 + a8ed0ca commit a72d6f8

File tree

6 files changed

+21
-6
lines changed

6 files changed

+21
-6
lines changed

frontend/guard_tracker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def record_function(self,
286286
if func in (min, max):
287287
scalar = None
288288
node = None
289-
assert len(pargs) == 2
289+
# NOTE: when pargs < 2, it should be a dynamic operation
290+
assert len(pargs) <= 2
290291
for i, obj in enumerate(pargs):
291292
if isinstance(obj, (int, float)) and not dyn.contains(obj):
292293
scalar = obj
@@ -1548,7 +1549,9 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool:
15481549

15491550
def is_builtin_func(self, func: Callable[..., Any]) -> bool:
15501551
return func in (dict, tuple, set, list, hasattr, slice, range, len,
1551-
type, all, str.join, reversed, zip, iter, id, next)
1552+
type, all, str.join, reversed, zip, iter, id, next,
1553+
collections.OrderedDict, str.format, any, str,
1554+
str.split)
15521555

15531556
def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
15541557
print(dir(func))

frontend/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def get_root_module(func: Callable[..., Any]) -> str:
156156
if module is None or 'torch.distributions' in module_str:
157157
return ""
158158
root_module = module_str.split('.')[0]
159+
#NOTE: special cases in torchvision module, need to check whether this module is safe to record in graph
160+
if hasattr(func, '__name__') and func.__name__ in (
161+
'pad', 'resize') and root_module == 'torchvision':
162+
return 'torch'
159163
return root_module
160164

161165

frontend/variables/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def make_var_from_value(
5454
extract_code_at_start: Optional[list[StorePos]] = None) -> Variable:
5555
if extract_code_at_start is None:
5656
extract_code_at_start = []
57+
if type(value) == np.ndarray and value.size == 1:
58+
return NumpyScalarVar.from_value(np.int64(value.tolist()),
59+
need_guard_check, helper_functions,
60+
fx_graph, extract_code_at_start)
5761
if type(value) in ty2var:
5862
return ty2var[type(value)].from_value(value, need_guard_check,
5963
helper_functions, fx_graph,

frontend/variables/dict_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
7676
items = []
7777
for key, j in zip(self.value.keys(), range(len(self.vars))):
7878
if isinstance(key, str):
79-
key_part = f"'{key}'"
79+
if "\n" not in key:
80+
key_part = f"'{key}'"
81+
else:
82+
key_part = f"'{repr(key)}'"
83+
key_part = key_part.strip("'")
8084
else:
8185
key_part = key
8286
item = f'{key_part}: {name_in_graph_fn}_{j}'

frontend/variables/list_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
114114
extract_code_at_start: list[StorePos]) -> None:
115115
super().__init__(need_guard_check, value, extract_code_at_start)
116116
self.value = value
117-
self.length = len(value)
117+
self.length = value.size
118118
self.vars = []
119119
self.obj_ids = []
120120
for i, obj in enumerate(value):

frontend/variables/tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def make_guard_inner(self, codegen: "GuardFnCodegen",
238238
def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
239239
codegen: "GraphFnCodegen", in_return: bool,
240240
idx: int) -> None:
241-
codegen.output(name_in_graph_fn, store_pos, f"{self.device}", in_return,
242-
idx)
241+
codegen.output(name_in_graph_fn, store_pos, f"'{self.device}'",
242+
in_return, idx)
243243

244244
def as_fx_node(self) -> "NodeArgs":
245245
return self.device

0 commit comments

Comments
 (0)