11"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+
23from collections import defaultdict
34from pathlib import Path
45from typing import Union
1213class LineProfilerDecoratorAdder (cst .CSTTransformer ):
1314 """Transformer that adds a decorator to a function with a specific qualified name."""
1415
15- #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
16+ # TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
1617 def __init__ (self , qualified_name : str , decorator_name : str ):
1718 """Initialize the transformer.
1819
@@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4546 function_name = original_node .name .value
4647
4748 # Check if the current context path matches our target qualified name
48- if self .context_stack == self .qualified_name_parts :
49+ if self .context_stack == self .qualified_name_parts :
4950 # Check if the decorator is already present
5051 has_decorator = any (
51- self ._is_target_decorator (decorator .decorator )
52- for decorator in original_node .decorators
52+ self ._is_target_decorator (decorator .decorator ) for decorator in original_node .decorators
5353 )
5454
5555 # Only add the decorator if it's not already there
5656 if not has_decorator :
57- new_decorator = cst .Decorator (
58- decorator = cst .Name (value = self .decorator_name )
59- )
57+ new_decorator = cst .Decorator (decorator = cst .Name (value = self .decorator_name ))
6058
6159 # Add our new decorator to the existing decorators
6260 updated_decorators = [new_decorator ] + list (updated_node .decorators )
63- updated_node = updated_node .with_changes (
64- decorators = tuple (updated_decorators )
65- )
61+ updated_node = updated_node .with_changes (decorators = tuple (updated_decorators ))
6662
6763 # Pop the context when we leave a function
6864 self .context_stack .pop ()
@@ -76,22 +72,21 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
7672 return decorator_node .func .value == self .decorator_name
7773 return False
7874
75+
7976class ProfileEnableTransformer (cst .CSTTransformer ):
80- def __init__ (self ,filename ):
81- # Flag to track if we found the import statement
82- self .found_import = False
83- # Track indentation of the import statement
84- self .import_indentation = None
85- self .filename = filename
77+ def __init__ (self , line_profile_output_file : str ):
78+ self .line_profile_output_file = line_profile_output_file
8679
8780 def leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) -> cst .ImportFrom :
8881 # Check if this is the line profiler import statement
89- if (isinstance (original_node .module , cst .Name ) and
90- original_node .module .value == "line_profiler" and
91- any (name .name .value == "profile" and
92- (not name .asname or name .asname .name .value == "codeflash_line_profile" )
93- for name in original_node .names )):
94-
82+ if (
83+ isinstance (original_node .module , cst .Name )
84+ and original_node .module .value == "line_profiler"
85+ and any (
86+ name .name .value == "profile" and (not name .asname or name .asname .name .value == "codeflash_line_profile" )
87+ for name in original_node .names
88+ )
89+ ):
9590 self .found_import = True
9691 # Get the indentation from the original node
9792 if hasattr (original_node , "leading_lines" ):
@@ -113,28 +108,39 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
113108 if isinstance (stmt , cst .SimpleStatementLine ):
114109 for small_stmt in stmt .body :
115110 if isinstance (small_stmt , cst .ImportFrom ):
116- if (isinstance (small_stmt .module , cst .Name ) and
117- small_stmt .module .value == "line_profiler" and
118- any (name .name .value == "profile" and
119- (not name .asname or name .asname .name .value == "codeflash_line_profile" )
120- for name in small_stmt .names )):
111+ if (
112+ isinstance (small_stmt .module , cst .Name )
113+ and small_stmt .module .value == "line_profiler"
114+ and any (
115+ name .name .value == "profile"
116+ and (not name .asname or name .asname .name .value == "codeflash_line_profile" )
117+ for name in small_stmt .names
118+ )
119+ ):
121120 import_index = i
122121 break
123122 if import_index is not None :
124123 break
125124
126125 if import_index is not None :
127126 # Create the new enable statement to insert after the import
128- enable_statement = cst .parse_statement (
129- f"codeflash_line_profile.enable(output_prefix='{ self .filename } ')"
130- )
127+ enable_statement = cst .parse_statement (f"codeflash_line_profile.enable(output_prefix='{ self .filename } ')" )
131128
132129 # Insert the new statement after the import statement
133130 new_body .insert (import_index + 1 , enable_statement )
134131
135132 # Create a new module with the updated body
136133 return updated_node .with_changes (body = new_body )
137134
135+ def __init__ (self , line_profile_output_file : str ):
136+ self .line_profile_output_file = line_profile_output_file
137+
138+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
139+ # This is a simplified example of the transformation logic
140+ new_decorator = cst .Decorator (decorator = cst .Name (value = "codeflash_line_profile" ))
141+ return updated_node .with_changes (decorators = [* updated_node .decorators , new_decorator ])
142+
143+
138144def add_decorator_to_qualified_function (module , qualified_name , decorator_name ):
139145 """Add a decorator to a function with the exact qualified name in the source code.
140146
@@ -156,9 +162,20 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
156162 # Convert the modified CST back to source code
157163 return modified_module
158164
165+
159166def add_profile_enable (original_code : str , line_profile_output_file : str ) -> str :
160- # TODO modify by using a libcst transformer
167+ # Avoid unnecessary transformations
168+ if not original_code .strip ():
169+ return original_code
170+
171+ # Parse the module only once
161172 module = cst .parse_module (original_code )
173+
174+ # If we can determine whether the transformer needs to be applied, we can shortcut
175+ if not has_transformable_content (module ):
176+ return original_code
177+
178+ # Apply transformer optimally
162179 transformer = ProfileEnableTransformer (line_profile_output_file )
163180 modified_module = module .visit (transformer )
164181 return modified_module .code
@@ -178,9 +195,7 @@ def leave_Module(self, original_node, updated_node):
178195 import_node = cst .parse_statement (self .import_statement )
179196
180197 # Add the import to the module's body
181- return updated_node .with_changes (
182- body = [import_node ] + list (updated_node .body )
183- )
198+ return updated_node .with_changes (body = [import_node ] + list (updated_node .body ))
184199
185200 def visit_ImportFrom (self , node ):
186201 # Check if the profile is already imported from line_profiler
@@ -192,15 +207,15 @@ def visit_ImportFrom(self, node):
192207
193208def add_decorator_imports (function_to_optimize , code_context ):
194209 """Adds a profile decorator to a function in a Python file and all its helper functions."""
195- #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
196- #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
210+ # self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
211+ # grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
197212 file_paths = defaultdict (list )
198213 line_profile_output_file = get_run_tmp_file (Path ("baseline_lprof" ))
199214 file_paths [function_to_optimize .file_path ].append (function_to_optimize .qualified_name )
200215 for elem in code_context .helper_functions :
201216 file_paths [elem .file_path ].append (elem .qualified_name )
202- for file_path ,fns_present in file_paths .items ():
203- #open file
217+ for file_path , fns_present in file_paths .items ():
218+ # open file
204219 file_contents = file_path .read_text ("utf-8" )
205220 # parse to cst
206221 module_node = cst .parse_module (file_contents )
@@ -216,8 +231,32 @@ def add_decorator_imports(function_to_optimize, code_context):
216231 # write to file
217232 with open (file_path , "w" , encoding = "utf-8" ) as file :
218233 file .write (modified_code )
219- #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
234+ # Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
220235 file_contents = function_to_optimize .file_path .read_text ("utf-8" )
221- modified_code = add_profile_enable (file_contents ,str (line_profile_output_file ))
222- function_to_optimize .file_path .write_text (modified_code ,"utf-8" )
236+ modified_code = add_profile_enable (file_contents , str (line_profile_output_file ))
237+ function_to_optimize .file_path .write_text (modified_code , "utf-8" )
223238 return line_profile_output_file
239+
240+
241+ def has_transformable_content (module ) -> bool :
242+ """Function to quickly check if the module has content that needs to be transformed.
243+ This can help in reducing unnecessary transformations.
244+ """
245+
246+ # A simple check to see if the transformer is needed (can be more complex as required)
247+ # For example, checking if the profile decorators are already present
248+ class CheckVisitor (cst .CSTVisitor ):
249+ def __init__ (self ):
250+ self .has_target = False
251+
252+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
253+ for deco in node .decorators :
254+ if deco .decorator .value == "codeflash_line_profile" :
255+ self .has_target = True
256+ return False # Stop visiting further
257+
258+ return True
259+
260+ visitor = CheckVisitor ()
261+ module .visit (visitor )
262+ return visitor .has_target
0 commit comments