Skip to content

Commit d4707e2

Browse files
nikithamalgifacebook-github-bot
authored andcommitted
Infer types (pytorch#56832)
Summary: Addresses: Infer argument types for functions in JIT Pull Request resolved: pytorch#56832 Reviewed By: pbelevich Differential Revision: D27979495 Pulled By: nikithamalgifb fbshipit-source-id: 82156a516c7f96cdd3f7a067d41cb210a6d13a51
1 parent e97c17a commit d4707e2

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

torch/jit/frontend.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
DictComp,
1818
)
1919
from torch._utils_internal import get_source_lines_and_file
20-
20+
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
2121
from torch._jit_internal import SourceContext, should_drop, is_static_fn
2222
import torch.jit.annotations
2323

@@ -289,8 +289,15 @@ def _forward(self):
289289
# Replace potentially unsupported type annotations by "Any"
290290
arg.annotation = unused_def.args.args[0].annotation
291291

292-
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
292+
# If MonkeyType is installed, get all the consolidated type traces
293+
# for the arguments from type_trace_db
294+
type_trace_db = torch.jit._script._get_type_trace_db()
295+
pdt_arg_types = None
296+
if monkeytype_trace:
297+
qualname = get_qualified_name(fn)
298+
pdt_arg_types = type_trace_db.get_args_types(qualname)
293299

300+
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
294301

295302
class Builder(object):
296303
def __call__(self, ctx, node):
@@ -306,15 +313,17 @@ def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
306313
return ClassDef(Ident(r, self_name), [Stmt(method) for method in methods], properties, assigns)
307314

308315

309-
def build_def(ctx, py_def, type_line, def_name, self_name=None):
316+
def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
310317
body = py_def.body
311318
r = ctx.make_range(py_def.lineno + len(py_def.decorator_list),
312319
py_def.col_offset,
313320
py_def.col_offset + len("def"))
314-
param_list = build_param_list(ctx, py_def.args, self_name)
321+
322+
param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
315323
return_type = None
316324
if getattr(py_def, 'returns', None) is not None:
317325
return_type = build_expr(ctx, py_def.returns)
326+
318327
decl = Decl(r, param_list, return_type)
319328
is_method = self_name is not None
320329
if type_line is not None:
@@ -330,7 +339,7 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None):
330339
"or use keyword-only arguments with defaults")
331340

332341

333-
def build_param_list(ctx, py_args, self_name):
342+
def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
334343
if py_args.kwarg is not None:
335344
expr = py_args.kwarg
336345
ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg))
@@ -346,17 +355,27 @@ def build_param_list(ctx, py_args, self_name):
346355
if arg is not None:
347356
ctx_range = build_expr(ctx, arg).range()
348357
raise NotSupportedError(ctx_range, _vararg_kwarg_err)
349-
result = [build_param(ctx, arg, self_name, False) for arg in py_args.args]
350-
result += [build_param(ctx, arg, self_name, True) for arg in py_args.kwonlyargs]
358+
359+
# List of Tuple of args and type as inferred by profile directed typing
360+
arg_and_types = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None)
361+
for arg in py_args.args]
362+
arg_and_types_kwonlyargs = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg])
363+
else None) for arg in py_args.kwonlyargs]
364+
365+
result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type) for arg, arg_type in arg_and_types]
366+
result += [build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=arg_type)
367+
for arg, arg_type in arg_and_types_kwonlyargs]
351368
return result
352369

353370

354-
def build_param(ctx, py_arg, self_name, kwarg_only):
371+
def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
355372
# NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
356373
name = py_arg.arg
357374
r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
358375
if getattr(py_arg, 'annotation', None) is not None:
359376
annotation_expr = build_expr(ctx, py_arg.annotation)
377+
elif pdt_arg_type:
378+
annotation_expr = Var(Ident(r, pdt_arg_type))
360379
elif self_name is not None and name == 'self':
361380
annotation_expr = Var(Ident(r, self_name))
362381
else:

0 commit comments

Comments
 (0)