11import os
22from pathlib import Path
3+ from tempfile import TemporaryDirectory
34
45from codeflash .code_utils .line_profile_utils import add_decorator_imports
56from codeflash .discovery .functions_to_optimize import FunctionToOptimize
@@ -12,7 +13,6 @@ def test_add_decorator_imports_helper_in_class():
1213 code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/bubble_sort_classmethod.py" ).resolve ()
1314 tests_root = Path (__file__ ).parent .resolve () / "../code_to_optimize/tests/pytest/"
1415 project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
15- original_cwd = Path .cwd ()
1616 run_cwd = Path (__file__ ).parent .parent .resolve ()
1717 test_config = TestConfig (
1818 tests_root = tests_root ,
@@ -78,10 +78,10 @@ def helper(self, arr, j):
7878 )
7979
8080def test_add_decorator_imports_helper_in_nested_class ():
81+ #Need to invert the assert once the helper detection is fixed
8182 code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/bubble_sort_nested_classmethod.py" ).resolve ()
8283 tests_root = Path (__file__ ).parent .resolve () / "../code_to_optimize/tests/pytest/"
8384 project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
84- original_cwd = Path .cwd ()
8585 run_cwd = Path (__file__ ).parent .parent .resolve ()
8686 test_config = TestConfig (
8787 tests_root = tests_root ,
@@ -153,3 +153,170 @@ def helper(self, arr, j):
153153 func_optimizer .write_code_and_helpers (
154154 func_optimizer .function_to_optimize_source_code , original_helper_code , func_optimizer .function_to_optimize .file_path
155155 )
156+
157+ def test_add_decorator_imports_nodeps ():
158+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/bubble_sort.py" ).resolve ()
159+ tests_root = Path (__file__ ).parent .resolve () / "../code_to_optimize/tests/pytest/"
160+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
161+ run_cwd = Path (__file__ ).parent .parent .resolve ()
162+ test_config = TestConfig (
163+ tests_root = tests_root ,
164+ tests_project_rootdir = project_root_path ,
165+ project_root_path = project_root_path ,
166+ test_framework = "pytest" ,
167+ pytest_cmd = "pytest" ,
168+ )
169+ func = FunctionToOptimize (function_name = "sorter" , parents = [], file_path = code_path )
170+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
171+ os .chdir (run_cwd )
172+ #func_optimizer = pass
173+ try :
174+ ctx_result = func_optimizer .get_code_optimization_context ()
175+ code_context : CodeOptimizationContext = ctx_result .unwrap ()
176+ original_helper_code : dict [Path , str ] = {}
177+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
178+ for helper_function_path in helper_function_paths :
179+ with helper_function_path .open (encoding = "utf8" ) as f :
180+ helper_code = f .read ()
181+ original_helper_code [helper_function_path ] = helper_code
182+ line_profiler_output_file = add_decorator_imports (
183+ func_optimizer .function_to_optimize , code_context )
184+ expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
185+ codeflash_line_profile.enable(output_prefix='{ line_profiler_output_file } ')
186+
187+
188+ @codeflash_line_profile
189+ def sorter(arr):
190+ print("codeflash stdout: Sorting list")
191+ for i in range(len(arr)):
192+ for j in range(len(arr) - 1):
193+ if arr[j] > arr[j + 1]:
194+ temp = arr[j]
195+ arr[j] = arr[j + 1]
196+ arr[j + 1] = temp
197+ print(f"result: {{arr}}")
198+ return arr
199+ """
200+ assert code_path .read_text ("utf-8" ) == expected_code_main
201+ finally :
202+ func_optimizer .write_code_and_helpers (
203+ func_optimizer .function_to_optimize_source_code , original_helper_code , func_optimizer .function_to_optimize .file_path
204+ )
205+
206+ def test_add_decorator_imports_helper_outside ():
207+ code_path = (Path (__file__ ).parent .resolve () / "../code_to_optimize/bubble_sort_deps.py" ).resolve ()
208+ tests_root = Path (__file__ ).parent .resolve () / "../code_to_optimize/tests/pytest/"
209+ project_root_path = (Path (__file__ ).parent / ".." ).resolve ()
210+ run_cwd = Path (__file__ ).parent .parent .resolve ()
211+ test_config = TestConfig (
212+ tests_root = tests_root ,
213+ tests_project_rootdir = project_root_path ,
214+ project_root_path = project_root_path ,
215+ test_framework = "pytest" ,
216+ pytest_cmd = "pytest" ,
217+ )
218+ func = FunctionToOptimize (function_name = "sorter_deps" , parents = [], file_path = code_path )
219+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
220+ os .chdir (run_cwd )
221+ #func_optimizer = pass
222+ try :
223+ ctx_result = func_optimizer .get_code_optimization_context ()
224+ code_context : CodeOptimizationContext = ctx_result .unwrap ()
225+ original_helper_code : dict [Path , str ] = {}
226+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
227+ for helper_function_path in helper_function_paths :
228+ with helper_function_path .open (encoding = "utf8" ) as f :
229+ helper_code = f .read ()
230+ original_helper_code [helper_function_path ] = helper_code
231+ line_profiler_output_file = add_decorator_imports (
232+ func_optimizer .function_to_optimize , code_context )
233+ expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
234+ codeflash_line_profile.enable(output_prefix='{ line_profiler_output_file } ')
235+
236+ from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
237+ from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
238+
239+
240+ @codeflash_line_profile
241+ def sorter_deps(arr):
242+ for i in range(len(arr)):
243+ for j in range(len(arr) - 1):
244+ if dep1_comparer(arr, j):
245+ dep2_swap(arr, j)
246+ return arr
247+
248+ """
249+ expected_code_helper1 = """from line_profiler import profile as codeflash_line_profile
250+
251+
252+ @codeflash_line_profile
253+ def dep1_comparer(arr, j: int) -> bool:
254+ return arr[j] > arr[j + 1]
255+ """
256+ expected_code_helper2 = """from line_profiler import profile as codeflash_line_profile
257+
258+
259+ @codeflash_line_profile
260+ def dep2_swap(arr, j):
261+ temp = arr[j]
262+ arr[j] = arr[j + 1]
263+ arr[j + 1] = temp
264+ """
265+ assert code_path .read_text ("utf-8" ) == expected_code_main
266+ assert code_context .helper_functions [0 ].file_path .read_text ("utf-8" ) == expected_code_helper1
267+ assert code_context .helper_functions [1 ].file_path .read_text ("utf-8" ) == expected_code_helper2
268+ finally :
269+ func_optimizer .write_code_and_helpers (
270+ func_optimizer .function_to_optimize_source_code , original_helper_code , func_optimizer .function_to_optimize .file_path
271+ )
272+
273+ def test_add_decorator_imports_helper_in_dunder_class ():
274+ code_str = """def sorter(arr):
275+ ans = helper(arr)
276+ return ans
277+ class helper:
278+ def __init__(self, arr):
279+ return arr.sort()"""
280+ code_path = TemporaryDirectory ()
281+ code_write_path = Path (code_path .name ) / "dunder_class.py"
282+ code_write_path .write_text (code_str ,"utf-8" )
283+ tests_root = Path (__file__ ).parent .resolve () / "../code_to_optimize/tests/pytest/"
284+ project_root_path = Path (code_path .name )
285+ run_cwd = Path (__file__ ).parent .parent .resolve ()
286+ test_config = TestConfig (
287+ tests_root = tests_root ,
288+ tests_project_rootdir = project_root_path ,
289+ project_root_path = project_root_path ,
290+ test_framework = "pytest" ,
291+ pytest_cmd = "pytest" ,
292+ )
293+ func = FunctionToOptimize (function_name = "sorter" , parents = [], file_path = code_write_path )
294+ func_optimizer = FunctionOptimizer (function_to_optimize = func , test_cfg = test_config )
295+ os .chdir (run_cwd )
296+ #func_optimizer = pass
297+ try :
298+ ctx_result = func_optimizer .get_code_optimization_context ()
299+ code_context : CodeOptimizationContext = ctx_result .unwrap ()
300+ original_helper_code : dict [Path , str ] = {}
301+ helper_function_paths = {hf .file_path for hf in code_context .helper_functions }
302+ for helper_function_path in helper_function_paths :
303+ with helper_function_path .open (encoding = "utf8" ) as f :
304+ helper_code = f .read ()
305+ original_helper_code [helper_function_path ] = helper_code
306+ line_profiler_output_file = add_decorator_imports (
307+ func_optimizer .function_to_optimize , code_context )
308+ expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
309+ codeflash_line_profile.enable(output_prefix='{ line_profiler_output_file } ')
310+
311+
312+ @codeflash_line_profile
313+ def sorter(arr):
314+ ans = helper(arr)
315+ return ans
316+ class helper:
317+ def __init__(self, arr):
318+ return arr.sort()
319+ """
320+ assert code_write_path .read_text ("utf-8" ) == expected_code_main
321+ finally :
322+ pass
0 commit comments