1
+ from os .path import commonprefix
1
2
from typing import Sequence
2
3
3
4
import libcst as cst
5
+ from libcst .codemod .visitors import ImportItem
6
+
4
7
from ...common import TorchVisitor
5
8
6
9
@@ -32,7 +35,32 @@ def visit_Call(self, node):
32
35
public_name = self .ALIASES [qualified_name ]
33
36
error_code = self .ERROR_CODE [0 ]
34
37
message = f"Use of non-public function `{ qualified_name } `, please use `{ public_name } ` instead" # noqa: E501
35
- self .add_violation (node , error_code = error_code , message = message )
38
+
39
+ call_name = cst .helpers .get_full_name_for_node (node )
40
+ replacement = None
41
+ if not public_name .endswith (call_name ):
42
+ # We need to change the call name as it's not in the public name.
43
+ # Get the new call name on the same hierarchical level.
44
+ new_call_name = public_name .removeprefix (
45
+ commonprefix ([qualified_name .removesuffix (call_name ), public_name ])
46
+ )
47
+ new_module_name = public_name .removesuffix (new_call_name ).removesuffix (
48
+ "."
49
+ )
50
+ if new_module_name :
51
+ self .needed_imports .add (
52
+ ImportItem (
53
+ module_name = new_module_name ,
54
+ obj_name = new_call_name .split ("." )[0 ],
55
+ )
56
+ )
57
+ replacement = node .with_changes (
58
+ func = cst .parse_expression (new_call_name )
59
+ )
60
+
61
+ self .add_violation (
62
+ node , error_code = error_code , message = message , replacement = replacement
63
+ )
36
64
37
65
def visit_ImportFrom (self , node : cst .ImportFrom ) -> None :
38
66
if node .module is None :
@@ -48,4 +76,20 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
48
76
public_name = self .ALIASES [qualified_name ]
49
77
error_code = self .ERROR_CODE [1 ]
50
78
message = f"Import of non-public function `{ qualified_name } `, please use `{ public_name } ` instead" # noqa: E501
51
- self .add_violation (node , error_code = error_code , message = message )
79
+
80
+ new_module = "." .join (public_name .split ("." )[:- 1 ])
81
+ new_name = public_name .split ("." )[- 1 ]
82
+ # Replace only if the import statement has no other names
83
+ if len (node .names ) == 1 :
84
+ replacement = cst .ImportFrom (
85
+ module = cst .parse_expression (new_module ), # type: ignore[arg-type] # noqa: E501
86
+ names = [cst .ImportAlias (name = cst .Name (new_name ))],
87
+ )
88
+ else :
89
+ replacement = None
90
+ self .add_violation (
91
+ node ,
92
+ error_code = error_code ,
93
+ message = message ,
94
+ replacement = replacement ,
95
+ )
0 commit comments