|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import contextlib |
| 4 | +import json |
| 5 | +import os |
3 | 6 | from dataclasses import dataclass |
4 | 7 | from pathlib import Path |
5 | 8 | from typing import TYPE_CHECKING |
@@ -55,18 +58,21 @@ def initialize_function_optimization( |
55 | 58 | return {"functionName": params.functionName, "status": "not found", "args": None} |
56 | 59 | fto = optimizable_funcs.popitem()[1][0] |
57 | 60 | server.optimizer.current_function_being_optimized = fto |
58 | | - return {"functionName": params.functionName, "status": "success", "info": fto.server_info} |
| 61 | + return {"functionName": params.functionName, "status": "success"} |
59 | 62 |
|
60 | 63 |
|
61 | 64 | @server.feature("discoverFunctionTests") |
62 | 65 | def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: |
63 | | - current_function = server.optimizer.current_function_being_optimized |
| 66 | + fto = server.optimizer.current_function_being_optimized |
| 67 | + optimizable_funcs = {fto.file_path: [fto]} |
| 68 | + |
| 69 | + devnull_writer = open(os.devnull, "w") # noqa |
| 70 | + with contextlib.redirect_stdout(devnull_writer): |
| 71 | + function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) |
64 | 72 |
|
65 | | - optimizable_funcs = {current_function.file_path: [current_function]} |
| 73 | + server.optimizer.discovered_tests = function_to_tests |
66 | 74 |
|
67 | | - function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) |
68 | | - # mocking in order to get things going |
69 | | - return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)} |
| 75 | + return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests} |
70 | 76 |
|
71 | 77 |
|
72 | 78 | @server.feature("prepareOptimization") |
@@ -145,6 +151,7 @@ def perform_function_optimization( |
145 | 151 | function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, |
146 | 152 | original_module_ast=original_module_ast, |
147 | 153 | original_module_path=current_function.file_path, |
| 154 | + function_to_tests=server.optimizer.discovered_tests or {}, |
148 | 155 | ) |
149 | 156 |
|
150 | 157 | server.optimizer.current_function_optimizer = function_optimizer |
@@ -214,13 +221,14 @@ def perform_function_optimization( |
214 | 221 | "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", |
215 | 222 | } |
216 | 223 |
|
217 | | - optimized_source = best_optimization.candidate.source_code # noqa: F841 |
| 224 | + optimized_source = best_optimization.candidate.source_code |
218 | 225 |
|
219 | 226 | return { |
220 | 227 | "functionName": params.functionName, |
221 | 228 | "status": "success", |
222 | 229 | "message": "Optimization completed successfully", |
223 | 230 | "extra": f"Speedup: {original_code_baseline.runtime / best_optimization.runtime:.2f}x faster", |
| 231 | + "optimization": json.dumps(optimized_source, indent=None), |
224 | 232 | } |
225 | 233 |
|
226 | 234 |
|
|
0 commit comments