[Feat] add staging argument #480

Open · wants to merge 3 commits into base: main
57 changes: 55 additions & 2 deletions codeflash/api/cfapi.py
@@ -14,13 +14,15 @@

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.version import __version__

if TYPE_CHECKING:
from requests import Response

from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.result.explanation import Explanation

from packaging import version

if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
@@ -175,6 +177,57 @@ def create_pr(
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)


def create_staging(
original_code: dict[Path, str],
new_code: dict[Path, str],
explanation: Explanation,
existing_tests_source: str,
generated_original_test_source: str,
function_trace_id: str,
coverage_message: str,
) -> Response:
"""Create a staging pull request, targeting the specified branch. (usually 'staging').

:param owner: The owner of the repository.
:param repo: The name of the repository.
:param base_branch: The base branch to target.
:param file_changes: A dictionary of file changes.
:param pr_comment: The pull request comment object, containing the optimization explanation, best runtime, etc.
:param generated_tests: The generated tests.
:return: The response object.
"""
# convert Path objects to strings
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()

build_file_changes = {
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
}
payload = {
"baseBranch": get_current_branch(),
"diffContents": build_file_changes,
"prCommentFields": PrComment(
optimization_explanation=explanation.explanation_message(),
best_runtime=explanation.best_runtime_ns,
original_runtime=explanation.original_runtime_ns,
function_name=explanation.function_name,
relative_file_path=relative_path,
speedup_x=explanation.speedup_x,
speedup_pct=explanation.speedup_pct,
winning_behavioral_test_results=explanation.winning_behavioral_test_results,
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
benchmark_details=explanation.benchmark_details,
).to_json(),
"existingTests": existing_tests_source,
"generatedTests": generated_original_test_source,
"traceId": function_trace_id,
"coverage_message": coverage_message,
}
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)


def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
"""Check if the Codeflash GitHub App is installed on the specified repository.

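For orientation, here is a hedged sketch of how a caller might invoke the new helper. It is not part of the diff: the wrapper function, the empty test sources, and the placeholder trace ID are illustrative, and the real caller is process_review in function_optimizer.py below. It assumes this branch of codeflash is installed and that the code runs inside the git repository being optimized, since create_staging resolves git_root_dir() and get_current_branch() itself.

from pathlib import Path

from codeflash.api.cfapi import create_staging


def submit_for_staging_review(explanation, original_source: str, optimized_source: str) -> bool:
    """Upload a single-file optimization to the staging review endpoint and report success."""
    file_path = Path(explanation.file_path)
    response = create_staging(
        original_code={file_path: original_source},  # mapping of file path -> original contents
        new_code={file_path: optimized_source},  # mapping of file path -> optimized contents
        explanation=explanation,  # Explanation produced by an earlier optimization run
        existing_tests_source="",  # no pre-existing tests in this sketch
        generated_original_test_source="",  # generated tests omitted for brevity
        function_trace_id="example-trace-id",  # placeholder trace ID
        coverage_message="Coverage data not available",
    )
    return response.ok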
1 change: 1 addition & 0 deletions codeflash/cli_cmds/cli.py
@@ -47,6 +47,7 @@ def parse_args() -> Namespace:
parser.add_argument(
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
)
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
parser.add_argument(
"--verify-setup",
action="store_true",
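As a quick sanity check on the flag itself, a minimal, self-contained sketch is shown here; the parser below is a stripped-down stand-in rather than codeflash's real parse_args. It illustrates that argparse maps the dashed option name to the args.staging_review attribute that process_review reads.

from argparse import ArgumentParser

# Simplified stand-in for codeflash's CLI setup, showing only the new flag.
parser = ArgumentParser(prog="codeflash")
parser.add_argument(
    "--staging-review", action="store_true", help="Upload optimizations to staging for review"
)

args = parser.parse_args(["--staging-review"])
assert args.staging_review is True  # argparse converts --staging-review to args.staging_review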
147 changes: 91 additions & 56 deletions codeflash/optimization/function_optimizer.py
@@ -19,7 +19,7 @@
from rich.tree import Tree

from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success
from codeflash.api.cfapi import add_code_context_hash, create_staging, mark_optimization_success
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
@@ -997,64 +997,99 @@ def find_and_process_best_optimization(
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
new_code_combined = new_helper_code.copy()
new_code_combined[explanation.file_path] = new_code
if not self.args.no_pr:
coverage_message = (
original_code_baseline.coverage_results.build_message()
if original_code_baseline.coverage_results
else "Coverage data not available"
)
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)
original_runtime_by_test = (
original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
)
optimized_runtime_by_test = (
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
)
# Add runtime comments to generated tests before creating the PR
generated_tests = add_runtime_comments_to_generated_tests(
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
)
generated_tests_str = "\n\n".join(
[test.generated_original_test_source for test in generated_tests.generated_tests]
)
existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
test_cfg=self.test_cfg,
original_runtimes_all=original_runtime_by_test,
optimized_runtimes_all=optimized_runtime_by_test,
)
if concolic_test_str:
generated_tests_str += "\n\n" + concolic_test_str

check_create_pr(
original_code=original_code_combined,
new_code=new_code_combined,
explanation=explanation,
existing_tests_source=existing_tests,
generated_original_test_source=generated_tests_str,
function_trace_id=self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id,
coverage_message=coverage_message,
git_remote=self.args.git_remote,
)
if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
self.write_code_and_helpers(
self.function_to_optimize_source_code,
original_helper_code,
self.function_to_optimize.file_path,
)
else:
# Mark optimization success since no PR will be created
mark_optimization_success(
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
)
self.process_review(
original_code_baseline,
best_optimization,
generated_tests,
test_functions_to_remove,
concolic_test_str,
original_code_combined,
new_code_combined,
explanation,
function_to_all_tests,
exp_type,
original_helper_code,
)
self.log_successful_optimization(explanation, generated_tests, exp_type)
return best_optimization

def process_review(
self,
original_code_baseline: OriginalCodeBaseline,
best_optimization: BestOptimization,
generated_tests: GeneratedTestsList,
test_functions_to_remove: list[str],
concolic_test_str: str | None,
original_code_combined: dict[Path, str],
new_code_combined: dict[Path, str],
explanation: Explanation,
function_to_all_tests: dict[str, set[FunctionCalledInTest]],
exp_type: str,
original_helper_code: dict[Path, str],
) -> None:
coverage_message = (
original_code_baseline.coverage_results.build_message()
if original_code_baseline.coverage_results
else "Coverage data not available"
)

generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)

original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
optimized_runtime_by_test = (
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
)

generated_tests = add_runtime_comments_to_generated_tests(
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
)

generated_tests_str = "\n\n".join(
[test.generated_original_test_source for test in generated_tests.generated_tests]
)
if concolic_test_str:
generated_tests_str += "\n\n" + concolic_test_str

existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
test_cfg=self.test_cfg,
original_runtimes_all=original_runtime_by_test,
optimized_runtimes_all=optimized_runtime_by_test,
)

data = {
"original_code": original_code_combined,
"new_code": new_code_combined,
"explanation": explanation,
"existing_tests_source": existing_tests,
"generated_original_test_source": generated_tests_str,
"function_trace_id": self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id,
"coverage_message": coverage_message,
}

if not self.args.no_pr and not self.args.staging_review:
data["git_remote"] = self.args.git_remote
check_create_pr(**data)
elif self.args.staging_review:
create_staging(**data)
else:
# Mark optimization success since no PR will be created
mark_optimization_success(
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
)

if ((not self.args.no_pr) or not self.args.staging_review) and (
self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function)
):
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)

def establish_original_code_baseline(
self,
code_context: CodeOptimizationContext,
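To summarize the new control flow, the sketch below paraphrases the three-way dispatch that process_review now performs between PR creation, staging upload, and local-only bookkeeping. It is illustrative rather than the real method: the shared payload is reduced to a plain dict, and the codeflash helpers are passed in as callables because only the cfapi imports are visible in this diff.

from __future__ import annotations

from typing import Any, Callable


def dispatch_review(
    args: Any,
    data: dict[str, Any],
    git_remote: str | None,
    trace_id: str,
    optimization_found: bool,
    check_create_pr: Callable[..., Any],
    create_staging: Callable[..., Any],
    mark_optimization_success: Callable[..., Any],
) -> None:
    """Paraphrase of the branch added in process_review (illustrative only)."""
    if not args.no_pr and not args.staging_review:
        # Default path: open a GitHub PR with the optimization payload.
        data["git_remote"] = git_remote
        check_create_pr(**data)
    elif args.staging_review:
        # New path: upload the same payload to the staging review endpoint.
        create_staging(**data)
    else:
        # --no-pr without --staging-review: nothing is uploaded,
        # so record the optimization result directly.
        mark_optimization_success(trace_id=trace_id, is_optimization_found=optimization_found)

One consequence of this ordering worth noting: passing --no-pr together with --staging-review still takes the staging path, because the elif on args.staging_review is evaluated whenever the PR condition fails.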