Skip to content

Commit b8614e0

Browse files
committed
tests galore
1 parent 2abbfba commit b8614e0

File tree

1 file changed

+169
-2
lines changed

1 file changed

+169
-2
lines changed

tests/test_instrument_line_profiler.py

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from pathlib import Path
3+
from tempfile import TemporaryDirectory
34

45
from codeflash.code_utils.line_profile_utils import add_decorator_imports
56
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@@ -12,7 +13,6 @@ def test_add_decorator_imports_helper_in_class():
1213
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_classmethod.py").resolve()
1314
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
1415
project_root_path = (Path(__file__).parent / "..").resolve()
15-
original_cwd = Path.cwd()
1616
run_cwd = Path(__file__).parent.parent.resolve()
1717
test_config = TestConfig(
1818
tests_root=tests_root,
@@ -78,10 +78,10 @@ def helper(self, arr, j):
7878
)
7979

8080
def test_add_decorator_imports_helper_in_nested_class():
81+
#Need to invert the assert once the helper detection is fixed
8182
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_nested_classmethod.py").resolve()
8283
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
8384
project_root_path = (Path(__file__).parent / "..").resolve()
84-
original_cwd = Path.cwd()
8585
run_cwd = Path(__file__).parent.parent.resolve()
8686
test_config = TestConfig(
8787
tests_root=tests_root,
@@ -153,3 +153,170 @@ def helper(self, arr, j):
153153
func_optimizer.write_code_and_helpers(
154154
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
155155
)
156+
157+
def test_add_decorator_imports_nodeps():
158+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
159+
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
160+
project_root_path = (Path(__file__).parent / "..").resolve()
161+
run_cwd = Path(__file__).parent.parent.resolve()
162+
test_config = TestConfig(
163+
tests_root=tests_root,
164+
tests_project_rootdir=project_root_path,
165+
project_root_path=project_root_path,
166+
test_framework="pytest",
167+
pytest_cmd="pytest",
168+
)
169+
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
170+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
171+
os.chdir(run_cwd)
172+
#func_optimizer = pass
173+
try:
174+
ctx_result = func_optimizer.get_code_optimization_context()
175+
code_context: CodeOptimizationContext = ctx_result.unwrap()
176+
original_helper_code: dict[Path, str] = {}
177+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
178+
for helper_function_path in helper_function_paths:
179+
with helper_function_path.open(encoding="utf8") as f:
180+
helper_code = f.read()
181+
original_helper_code[helper_function_path] = helper_code
182+
line_profiler_output_file = add_decorator_imports(
183+
func_optimizer.function_to_optimize, code_context)
184+
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
185+
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
186+
187+
188+
@codeflash_line_profile
189+
def sorter(arr):
190+
print("codeflash stdout: Sorting list")
191+
for i in range(len(arr)):
192+
for j in range(len(arr) - 1):
193+
if arr[j] > arr[j + 1]:
194+
temp = arr[j]
195+
arr[j] = arr[j + 1]
196+
arr[j + 1] = temp
197+
print(f"result: {{arr}}")
198+
return arr
199+
"""
200+
assert code_path.read_text("utf-8") == expected_code_main
201+
finally:
202+
func_optimizer.write_code_and_helpers(
203+
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
204+
)
205+
206+
def test_add_decorator_imports_helper_outside():
207+
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_deps.py").resolve()
208+
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
209+
project_root_path = (Path(__file__).parent / "..").resolve()
210+
run_cwd = Path(__file__).parent.parent.resolve()
211+
test_config = TestConfig(
212+
tests_root=tests_root,
213+
tests_project_rootdir=project_root_path,
214+
project_root_path=project_root_path,
215+
test_framework="pytest",
216+
pytest_cmd="pytest",
217+
)
218+
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
219+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
220+
os.chdir(run_cwd)
221+
#func_optimizer = pass
222+
try:
223+
ctx_result = func_optimizer.get_code_optimization_context()
224+
code_context: CodeOptimizationContext = ctx_result.unwrap()
225+
original_helper_code: dict[Path, str] = {}
226+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
227+
for helper_function_path in helper_function_paths:
228+
with helper_function_path.open(encoding="utf8") as f:
229+
helper_code = f.read()
230+
original_helper_code[helper_function_path] = helper_code
231+
line_profiler_output_file = add_decorator_imports(
232+
func_optimizer.function_to_optimize, code_context)
233+
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
234+
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
235+
236+
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
237+
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
238+
239+
240+
@codeflash_line_profile
241+
def sorter_deps(arr):
242+
for i in range(len(arr)):
243+
for j in range(len(arr) - 1):
244+
if dep1_comparer(arr, j):
245+
dep2_swap(arr, j)
246+
return arr
247+
248+
"""
249+
expected_code_helper1 = """from line_profiler import profile as codeflash_line_profile
250+
251+
252+
@codeflash_line_profile
253+
def dep1_comparer(arr, j: int) -> bool:
254+
return arr[j] > arr[j + 1]
255+
"""
256+
expected_code_helper2="""from line_profiler import profile as codeflash_line_profile
257+
258+
259+
@codeflash_line_profile
260+
def dep2_swap(arr, j):
261+
temp = arr[j]
262+
arr[j] = arr[j + 1]
263+
arr[j + 1] = temp
264+
"""
265+
assert code_path.read_text("utf-8") == expected_code_main
266+
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper1
267+
assert code_context.helper_functions[1].file_path.read_text("utf-8") == expected_code_helper2
268+
finally:
269+
func_optimizer.write_code_and_helpers(
270+
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
271+
)
272+
273+
def test_add_decorator_imports_helper_in_dunder_class():
274+
code_str = """def sorter(arr):
275+
ans = helper(arr)
276+
return ans
277+
class helper:
278+
def __init__(self, arr):
279+
return arr.sort()"""
280+
code_path = TemporaryDirectory()
281+
code_write_path = Path(code_path.name) / "dunder_class.py"
282+
code_write_path.write_text(code_str,"utf-8")
283+
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
284+
project_root_path = Path(code_path.name)
285+
run_cwd = Path(__file__).parent.parent.resolve()
286+
test_config = TestConfig(
287+
tests_root=tests_root,
288+
tests_project_rootdir=project_root_path,
289+
project_root_path=project_root_path,
290+
test_framework="pytest",
291+
pytest_cmd="pytest",
292+
)
293+
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
294+
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
295+
os.chdir(run_cwd)
296+
#func_optimizer = pass
297+
try:
298+
ctx_result = func_optimizer.get_code_optimization_context()
299+
code_context: CodeOptimizationContext = ctx_result.unwrap()
300+
original_helper_code: dict[Path, str] = {}
301+
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
302+
for helper_function_path in helper_function_paths:
303+
with helper_function_path.open(encoding="utf8") as f:
304+
helper_code = f.read()
305+
original_helper_code[helper_function_path] = helper_code
306+
line_profiler_output_file = add_decorator_imports(
307+
func_optimizer.function_to_optimize, code_context)
308+
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
309+
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
310+
311+
312+
@codeflash_line_profile
313+
def sorter(arr):
314+
ans = helper(arr)
315+
return ans
316+
class helper:
317+
def __init__(self, arr):
318+
return arr.sort()
319+
"""
320+
assert code_write_path.read_text("utf-8") == expected_code_main
321+
finally:
322+
pass

0 commit comments

Comments
 (0)