Skip to content

Save #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 13, 2024
Merged

Save #24

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/system.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ jobs:
run: uv run commit0 test-reference simpy tests/test_event.py::test_succeed
- name: Evaluate
run: uv run commit0 evaluate-reference simpy
- name: Save
env:
GITHUB_TOKEN: ${{ secrets.MY_GITHUB_TOKEN }}
run: |
uv run commit0 save simpy test-save-commit0
16 changes: 14 additions & 2 deletions commit0/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import commit0.harness.build
import commit0.harness.setup
import commit0.harness.evaluate
import commit0.harness.save
import copy
import sys
import os
Expand All @@ -29,8 +30,8 @@ def main() -> None:
# after hydra gets all configs, put command-line arguments back
sys.argv = sys_argv
# repo_split: split from command line has a higher priority than split in hydra
if command in ["clone", "build", "evaluate", "evaluate-reference"]:
if len(sys.argv) == 3:
if command in ["clone", "build", "evaluate", "evaluate-reference", "save"]:
if len(sys.argv) >= 3:
if sys.argv[2] not in SPLIT:
raise ValueError(
f"repo split must be from {', '.join(SPLIT.keys())}, but you provided {sys.argv[2]}"
Expand Down Expand Up @@ -85,6 +86,17 @@ def main() -> None:
config.timeout,
config.num_workers,
)
elif command == "save":
organization = sys.argv[3]
commit0.harness.save.main(
config.dataset_name,
config.dataset_split,
config.repo_split,
config.base_dir,
organization,
config.branch,
config.github_token,
)


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions commit0/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ num_workers: 8
backend: local
branch: ai
timeout: 1_800

# save related
github_token: null
4 changes: 4 additions & 0 deletions commit0/configs/config_class.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
Expand All @@ -21,3 +22,6 @@ class Commit0Config:
branch: str
# timeout for running pytest
timeout: int

# save related
github_token: Optional[str]
1 change: 1 addition & 0 deletions commit0/harness/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class RepoInstance(TypedDict):
"get-tests",
"evaluate",
"evaluate-reference",
"save",
]
# repo splits
SPLIT_MINITORCH = ["minitorch"]
Expand Down
83 changes: 83 additions & 0 deletions commit0/harness/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import logging
import os

import git

from datasets import load_dataset
from typing import Iterator
from commit0.harness.constants import RepoInstance, SPLIT
from commit0.harness.utils import create_repo_on_github


# Configure logging when this module is imported so that progress messages
# emitted below carry a timestamp, the logger name and the level.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def main(
    dataset_name: str,
    dataset_split: str,
    repo_split: str,
    base_dir: str,
    organization: str,
    branch: str,
    github_token: str,
) -> None:
    """Push ``branch`` of every selected local repository to GitHub.

    For each dataset entry whose repository belongs to ``repo_split`` (or for
    every entry when ``repo_split`` is ``"all"``), the local clone under
    ``base_dir`` is pushed to ``https://github.com/{organization}/{repo}``.
    The GitHub repository is created first when it does not exist, and any
    uncommitted work on ``branch`` is committed before pushing.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        dataset_split: Dataset split passed to ``load_dataset``.
        repo_split: Key of ``SPLIT`` selecting repositories, or ``"all"``.
        base_dir: Directory that contains the local clones.
        organization: GitHub organization that receives the repositories.
        branch: Local branch to check out, commit and push.
        github_token: GitHub token; when ``None`` the ``GITHUB_TOKEN``
            environment variable is used instead.

    Raises:
        OSError: If the local clone of a selected repository is missing.
        ValueError: If ``branch`` does not exist in a local clone.
        Exception: If pushing to GitHub fails or is rejected.
    """
    if github_token is None:
        # Get GitHub token from environment variable if not provided
        github_token = os.environ.get("GITHUB_TOKEN")
    dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split)  # type: ignore
    for example in dataset:
        repo_name = example["repo"].split("/")[-1]
        if repo_split != "all" and repo_name not in SPLIT[repo_split]:
            continue
        local_repo_path = f"{base_dir}/{repo_name}"
        # Credential-free URL: the only form that may appear in log output.
        public_repo_url = f"https://github.com/{organization}/{repo_name}.git"
        if github_token:
            github_repo_url = public_repo_url.replace(
                "https://", f"https://x-access-token:{github_token}@"
            )
        else:
            # No token available: do not embed the literal string "None" as a
            # credential; leave authentication to the local git configuration.
            github_repo_url = public_repo_url

        # The local clone must already exist; this command never clones.
        if not os.path.exists(local_repo_path):
            raise OSError(f"{local_repo_path} does not exists")
        repo = git.Repo(local_repo_path)

        # create Github repo
        create_repo_on_github(
            organization=organization, repo=repo_name, logger=logger, token=github_token
        )
        # Point the dedicated remote at the (possibly authenticated) URL.
        remote_name = "progress-tracker"
        if remote_name not in [remote.name for remote in repo.remotes]:
            repo.create_remote(remote_name, url=github_repo_url)
        else:
            # Log the public URL only, so the token never reaches the logs.
            logger.info(
                f"Remote {remote_name} already exists, replacing it with {public_repo_url}"
            )
            repo.remote(name=remote_name).set_url(github_repo_url)

        # Check if the branch already exists
        if branch in repo.heads:
            repo.git.checkout(branch)
        else:
            raise ValueError(f"The branch {branch} you want save does not exist.")

        # Commit pending work. Only a dirty tree (modified or untracked files)
        # has something to commit; a clean tree is pushed as it is.
        if repo.is_dirty(untracked_files=True):
            repo.git.add(A=True)
            repo.index.commit("AI generated code.")

        # Push to the GitHub repository
        tracker = repo.remote(name=remote_name)
        try:
            push_results = tracker.push(refspec=f"{branch}:{branch}")
            # GitPython signals a rejected head through the ERROR flag and a
            # completely failed push through an empty result list, not
            # necessarily through an exception.
            if not push_results or any(
                info.flags & git.PushInfo.ERROR for info in push_results
            ):
                summary = "; ".join(info.summary.strip() for info in push_results)
                raise RuntimeError(summary or "git push returned no result")
        except Exception as e:
            details = str(e)
            if github_token:
                # git may echo the remote URL in its output; mask the token.
                details = details.replace(github_token, "***")
            raise Exception(
                f"Push {branch} to {organization}/{repo_name} fails.\n{details}"
            ) from e
        logger.info(f"Pushed to {public_repo_url} on branch {branch}")


__all__ = []
30 changes: 30 additions & 0 deletions commit0/harness/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import logging
import socket
import os
import time
import requests
from typing import Optional

from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
from ghapi.core import GhApi
from commit0.harness.constants import EVAL_BACKENDS


Expand Down Expand Up @@ -170,4 +175,29 @@ def create_branch(repo: git.Repo, branch: str, logger: logging.Logger) -> None:
raise RuntimeError(f"Failed to create or switch to branch '{branch}': {e}")


def create_repo_on_github(
    organization: str, repo: str, logger: logging.Logger, token: Optional[str] = None
) -> None:
    """Make sure ``organization/repo`` exists on GitHub, creating it if needed.

    The repository is looked up first. A 404 response triggers its creation
    inside the organization. A 403 response caused by an exhausted core rate
    limit is waited out in 5 minute steps and the lookup is retried.

    Args:
        organization: GitHub organization that owns (or will own) the repo.
        repo: Repository name without the organization prefix.
        logger: Logger that receives progress messages.
        token: GitHub token handed to ``GhApi``; ``None`` passes no token.

    Raises:
        HTTP403ForbiddenError: If GitHub answers 403 although the core rate
            limit still has calls left, i.e. the request is forbidden for a
            reason that waiting cannot fix.
    """
    api = GhApi(token=token)
    while True:
        try:
            api.repos.get(owner=organization, repo=repo)  # type: ignore
            logger.info(f"{organization}/{repo} already exists")
            break
        except HTTP403ForbiddenError:
            rl = api.rate_limit.get()  # type: ignore
            if rl.resources.core.remaining > 0:
                # Quota is available, so this 403 is not rate limiting.
                # Retrying would loop forever without ever sleeping.
                raise
            while rl.resources.core.remaining <= 0:
                logger.info(
                    f"Rate limit exceeded for the current GitHub token, "
                    f"waiting for 5 minutes, remaining calls: {rl.resources.core.remaining}"
                )
                time.sleep(60 * 5)
                rl = api.rate_limit.get()  # type: ignore
        except HTTP404NotFoundError:
            api.repos.create_in_org(org=organization, name=repo)  # type: ignore
            logger.info(f"Created {organization}/{repo} on GitHub")
            break


__all__ = []
Loading