17
17
DictComp ,
18
18
)
19
19
from torch ._utils_internal import get_source_lines_and_file
20
-
20
+ from torch . jit . _monkeytype_config import monkeytype_trace , get_qualified_name
21
21
from torch ._jit_internal import SourceContext , should_drop , is_static_fn
22
22
import torch .jit .annotations
23
23
@@ -289,8 +289,15 @@ def _forward(self):
289
289
# Replace potentially unsupported type annotations by "Any"
290
290
arg .annotation = unused_def .args .args [0 ].annotation
291
291
292
- return build_def (ctx , fn_def , type_line , def_name , self_name = self_name )
292
+ # If MonkeyType is installed, get all the consolidated type traces
293
+ # for the arguments from type_trace_db
294
+ type_trace_db = torch .jit ._script ._get_type_trace_db ()
295
+ pdt_arg_types = None
296
+ if monkeytype_trace :
297
+ qualname = get_qualified_name (fn )
298
+ pdt_arg_types = type_trace_db .get_args_types (qualname )
293
299
300
+ return build_def (ctx , fn_def , type_line , def_name , self_name = self_name , pdt_arg_types = pdt_arg_types )
294
301
295
302
class Builder (object ):
296
303
def __call__ (self , ctx , node ):
@@ -306,15 +313,17 @@ def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
306
313
return ClassDef (Ident (r , self_name ), [Stmt (method ) for method in methods ], properties , assigns )
307
314
308
315
309
- def build_def (ctx , py_def , type_line , def_name , self_name = None ):
316
+ def build_def (ctx , py_def , type_line , def_name , self_name = None , pdt_arg_types = None ):
310
317
body = py_def .body
311
318
r = ctx .make_range (py_def .lineno + len (py_def .decorator_list ),
312
319
py_def .col_offset ,
313
320
py_def .col_offset + len ("def" ))
314
- param_list = build_param_list (ctx , py_def .args , self_name )
321
+
322
+ param_list = build_param_list (ctx , py_def .args , self_name , pdt_arg_types )
315
323
return_type = None
316
324
if getattr (py_def , 'returns' , None ) is not None :
317
325
return_type = build_expr (ctx , py_def .returns )
326
+
318
327
decl = Decl (r , param_list , return_type )
319
328
is_method = self_name is not None
320
329
if type_line is not None :
@@ -330,7 +339,7 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None):
330
339
"or use keyword-only arguments with defaults" )
331
340
332
341
333
- def build_param_list (ctx , py_args , self_name ):
342
+ def build_param_list (ctx , py_args , self_name , pdt_arg_types = None ):
334
343
if py_args .kwarg is not None :
335
344
expr = py_args .kwarg
336
345
ctx_range = ctx .make_range (expr .lineno , expr .col_offset - 1 , expr .col_offset + len (expr .arg ))
@@ -346,17 +355,27 @@ def build_param_list(ctx, py_args, self_name):
346
355
if arg is not None :
347
356
ctx_range = build_expr (ctx , arg ).range ()
348
357
raise NotSupportedError (ctx_range , _vararg_kwarg_err )
349
- result = [build_param (ctx , arg , self_name , False ) for arg in py_args .args ]
350
- result += [build_param (ctx , arg , self_name , True ) for arg in py_args .kwonlyargs ]
358
+
359
+ # List of Tuple of args and type as inferred by profile directed typing
360
+ arg_and_types = [(arg , next (iter (pdt_arg_types [arg .arg ])) if pdt_arg_types and bool (pdt_arg_types [arg .arg ]) else None )
361
+ for arg in py_args .args ]
362
+ arg_and_types_kwonlyargs = [(arg , next (iter (pdt_arg_types [arg .arg ])) if pdt_arg_types and bool (pdt_arg_types [arg .arg ])
363
+ else None ) for arg in py_args .kwonlyargs ]
364
+
365
+ result = [build_param (ctx , arg , self_name , kwarg_only = False , pdt_arg_type = arg_type ) for arg , arg_type in arg_and_types ]
366
+ result += [build_param (ctx , arg , self_name , kwarg_only = True , pdt_arg_type = arg_type )
367
+ for arg , arg_type in arg_and_types_kwonlyargs ]
351
368
return result
352
369
353
370
354
- def build_param (ctx , py_arg , self_name , kwarg_only ):
371
+ def build_param (ctx , py_arg , self_name , kwarg_only , pdt_arg_type = None ):
355
372
# NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
356
373
name = py_arg .arg
357
374
r = ctx .make_range (py_arg .lineno , py_arg .col_offset , py_arg .col_offset + len (name ))
358
375
if getattr (py_arg , 'annotation' , None ) is not None :
359
376
annotation_expr = build_expr (ctx , py_arg .annotation )
377
+ elif pdt_arg_type :
378
+ annotation_expr = Var (Ident (r , pdt_arg_type ))
360
379
elif self_name is not None and name == 'self' :
361
380
annotation_expr = Var (Ident (r , self_name ))
362
381
else :
0 commit comments