SWE-Benchmarks (from Huggingface) (#544)
granawkins authored Mar 26, 2024
1 parent 480063a commit 3976617
Showing 11 changed files with 202 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ build
benchmark_repos
docs/build
.DS_Store
benchmarks/benchmarks/swe_bench_samples
8 changes: 8 additions & 0 deletions benchmarks/arg_parser.py
@@ -64,5 +64,13 @@ def common_benchmark_parser():
action="store_true",
help="Evaluate the baseline for the benchmark",
)
parser.add_argument(
"--swe_bench",
nargs="?",
const="dev",
default=None,
type=str,
help="Fetch or load SWE-bench examples from split: dev (default), train or test.",
)

return parser
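
Because the new flag is declared with nargs="?", const="dev", and default=None, it can be passed with or without a value. A minimal sketch of the resulting parsing behavior (the argument lists are illustrative):

from benchmarks.arg_parser import common_benchmark_parser

parser = common_benchmark_parser()
print(parser.parse_args([]).swe_bench)                       # None  (flag omitted -> default)
print(parser.parse_args(["--swe_bench"]).swe_bench)          # "dev" (bare flag -> const)
print(parser.parse_args(["--swe_bench", "test"]).swe_bench)  # "test"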
10 changes: 10 additions & 0 deletions benchmarks/benchmark_runner.py
@@ -16,6 +16,7 @@
from benchmarks.benchmark_result import BenchmarkResult
from benchmarks.benchmark_run import BenchmarkRun
from benchmarks.run_sample import run_sample
from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples
from mentat.config import Config
from mentat.git_handler import get_git_diff, get_mentat_branch, get_mentat_hexsha
from mentat.llm_api_handler import model_context_size, prompt_tokens
@@ -314,6 +315,15 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1)
if __name__ == "__main__":
parser = common_benchmark_parser()
args = parser.parse_args()
if args.swe_bench:
if args.swe_bench not in {"dev", "train", "test"}:
print("Invalid SWE-Bench split.")
exit(1)
# Download and save SWE benchmarks as Samples
samples = get_swe_samples(args.swe_bench, args.max_benchmarks)
sample_titles = [sample.title for sample in samples]
args.benchmarks = sample_titles
args.directory = SWE_BENCH_SAMPLES_DIR / args.swe_bench
run_benchmarks(
args.benchmarks,
args.directory,
31 changes: 29 additions & 2 deletions benchmarks/run_sample.py
@@ -1,12 +1,14 @@
from pathlib import Path
import subprocess
import re
from typing import Any

from mentat import Mentat
from mentat.errors import SampleError
from mentat.git_handler import get_git_diff
from mentat.parsers.git_parser import GitParser
from mentat.sampler.sample import Sample
from mentat.sampler.utils import get_active_snapshot_commit, setup_repo
from mentat.sampler.utils import get_active_snapshot_commit, setup_repo, apply_diff_to_repo
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import convert_string_to_asynciter

@@ -45,7 +47,10 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str,
conversation.add_model_message(content, [], parsed_llm_response)
else:
raise SampleError(f"Invalid role found in message_history: {msg['role']}")
await mentat.call_mentat_auto_accept(sample.message_prompt)
prompt = sample.message_prompt
if sample.hint_text:
prompt += f"\n{80 * '-'}\nHint Text:\n{sample.hint_text}"
await mentat.call_mentat_auto_accept(prompt)
await mentat.shutdown()

# Get the diff between pre- and post-edit
@@ -54,6 +59,27 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str,
message_eval = str(transcript_messages[-1].get("message", ""))
diff_eval = get_git_diff(commit_active or "HEAD", cwd=cwd)

test_results = {"passed": 0, "failed": 0, "error": ""}
if sample.test_command:
if sample.test_patch:
apply_diff_to_repo(sample.test_patch, repo)
try:
output = subprocess.run(
sample.test_command,
shell=True,
capture_output=True,
text=True,
cwd=cwd,
)
matches = re.search(r"(?:(\d+) passed)?(?:, )?(?:(\d+) failed)?", output.stdout)
if matches:
test_results["passed"] = int(matches.group(1) or 0)
test_results["failed"] = int(matches.group(2) or 0)
else:
raise SampleError(f"Test command failed: {output.stdout}")
except Exception as e:
test_results["error"] = str(e)

return {
"id": sample.id,
"message_eval": message_eval,
@@ -64,4 +90,5 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str,
"id": sample.id,
"messages": transcript_messages,
},
"test_results": test_results,
}
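
The pass/fail counts are scraped from the test command's stdout with the regex above. A small sketch of that extraction, applied to an illustrative summary fragment (real output will vary):

import re

stdout = "3 passed, 2 failed in 1.23s"  # illustrative pytest summary fragment
matches = re.search(r"(?:(\d+) passed)?(?:, )?(?:(\d+) failed)?", stdout)
passed = int(matches.group(1) or 0)  # 3
failed = int(matches.group(2) or 0)  # 2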
38 changes: 38 additions & 0 deletions benchmarks/swe_bench_runner.py
@@ -0,0 +1,38 @@
from pathlib import Path

from datasets import load_dataset, DatasetDict # type: ignore

from mentat.sampler.sample import Sample


SWE_BENCH_SAMPLES_DIR = Path(__file__).parent / "benchmarks" / "swe_bench_samples"


def download_swe_benchmarks(split: str = "dev") -> list[dict[str, str]]:
"""3 splits are available: dev (225), test (2.29k), and train (19k)."""
dataset: DatasetDict = load_dataset("princeton-nlp/SWE-bench", split=split) # type: ignore
dataset: list[dict[str, str]] = [dict(benchmark) for benchmark in dataset]
return dataset


def get_swe_samples(split: str = "dev", max_benchmarks: int | None = None) -> list[Sample]:
"""Return a list of SWE-Bench samples.
If missing, download, convert to Samples and save locally.
"""
split_dir = SWE_BENCH_SAMPLES_DIR / split
saved_benchmarks = list(split_dir.glob("*.json"))
if not split_dir.exists() or (max_benchmarks and len(saved_benchmarks) < max_benchmarks):
print(f"Downloading {split} split from SWE-Bench...")
split_dir.mkdir(parents=True, exist_ok=True)
dataset = download_swe_benchmarks(split)
samples = [Sample.from_swe_bench(benchmark) for benchmark in dataset]
for sample in samples:
sample.save(split_dir / f"{sample.id}.json")
else:
samples = [Sample.load(fname) for fname in saved_benchmarks]

if max_benchmarks:
samples = samples[:max_benchmarks]
print(f"Selected {len(samples)} benchmarks from '{split}'")
return samples
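
A hedged usage sketch of the helper (the split name and cap are example values): the first call downloads the split and caches each benchmark as a Sample JSON file under swe_bench_samples/<split>/, and later calls load from that cache.

from benchmarks.swe_bench_runner import get_swe_samples

samples = get_swe_samples("dev", max_benchmarks=5)
for sample in samples:
    print(sample.title)  # e.g. "SWE-bench-<instance_id>"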
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -1,5 +1,6 @@
aiomultiprocess==0.9.0
black==23.9.1
datasets==2.18.0
gitpython==3.1.41
isort==5.12.0
pip-licenses==4.3.3
29 changes: 29 additions & 0 deletions mentat/sampler/CHANGELOG.md
@@ -0,0 +1,29 @@
# Changelog

All notable changes to this project will be documented in this file.

## [2024-03-22]: Add fields to cover content of SWE-Bench
Fields in common are left with their current name. Missing fields (*) are added.

| Sampler | SWE-Bench |
|----------------------------|--------------------------|
| title | |
| description | |
| id | instance_id |
| parent_id | |
| repo | repo |
| environment_setup_commit * | environment_setup_commit |
| merge_base | base_commit |
| diff_merge_base | |
| diff_active | |
| message_history | |
| message_prompt | problem_statement |
| hint_text * | hints_text |
| message_edit | |
| context | |
| diff_edit | patch |
| test_patch * | test_patch |
| test_command | FAIL_TO_PASS |
| PASS_TO_PASS * | PASS_TO_PASS |
| version | |
| | version |
| | created_at |
39 changes: 22 additions & 17 deletions mentat/sampler/README.md
@@ -12,23 +12,28 @@ In any github-connected repo:
## `Sample` API
A `Sample` captures interactions between a developer and any LLM Coding Assistant. It consists of a starting codebase, a user command, and the expected LLM response - text, a git diff, or both. It can also include a list of paths/line-numbers to be included with the prompt, diffs to setup the git environment, and more:

| Field | Req | Type | Description |
|------------------|-----|------------------------|-------------|
| title | | `str` | plaintext by creator |
| description | | `str` | plaintext by creator |
| id | | `uuid` | |
| parent_id | | `uuid` | id of sample immediately before this |
| repo | * | `str` | a url to download the code |
| merge_base | * | `str` | the latest permanent commit |
| diff_merge_base | | `str` | between merge_base and latest commit |
| diff_active | | `str` | between latest commit and active (pre-edit) code |
| args | | `list[str]` | list of `<relative_path>[:<start_line>-<end_line>]` |
| message_history | | `list[dict[str, str]]` | list of prior user and assistant messages |
| message_prompt | * | `str` | the sample task |
| message_edit | | `str` | plaintext response returned for sample edit |
| diff_edit | * | `str` | between starting (diff_head) and ending code. |
| test_command | | `str` | discrete pass/fail, e.g. ‘pytest -k diff_active’ |
| version | | `str` | current Sample API version |
| Field | Req | Type | Description |
|---------------------------|-----|------------------------|-------------|
| title | | `str` | plaintext by creator |
| description | | `str` | plaintext by creator |
| id | | `uuid` | |
| parent_id | | `uuid` | id of sample immediately before this |
| repo | * | `str` | a url to download the code |
| environment_setup_commit | | `str` | commit hash to use for environment setup and installation |
| merge_base | * | `str` | the latest permanent commit |
| diff_merge_base | | `str` | between merge_base and latest commit |
| diff_active | | `str` | between latest commit and active (pre-edit) code |
| context | | `list[str]` | list of `<relative_path>[:<start_line>-<end_line>]` |
| message_history | | `list[dict[str, str]]` | list of prior user and assistant messages |
| message_prompt | * | `str` | the sample task |
| hint_text | | `str` | extra information, e.g. GitHub issue comments |
| message_edit | | `str` | plaintext response returned for sample edit |
| diff_edit | * | `str` | between starting (diff_head) and ending code |
| test_patch | | `str` | a patch to test files, applied before running test_command |
| test_command | | `str` | discrete pass/fail, e.g. ‘pytest -k diff_active’ |
| PASS_TO_PASS | | `str` | discrete pass/fail command for tests expected to pass before and after the edit |
| version | | `str` | current Sample API version |

Notes:
- All diffs and code changes follow standard git-diff format (`diff --git a/new_filename...`)
2 changes: 1 addition & 1 deletion mentat/sampler/__init__.py
@@ -1 +1 @@
__version__ = "0.2.0"
__version__ = "0.3.0"
62 changes: 62 additions & 0 deletions mentat/sampler/sample.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import re
from pathlib import Path

import attr
@@ -17,15 +18,19 @@ class Sample:
id: str = attr.field(default="")
parent_id: str = attr.field(default="")
repo: str = attr.field(default="")
environment_setup_commit: str | None = attr.field(default=None)
merge_base: str | None = attr.field(default=None)
diff_merge_base: str = attr.field(default="")
diff_active: str = attr.field(default="")
message_history: list[dict[str, str]] = attr.field(default=[]) # type: ignore
message_prompt: str = attr.field(default="")
hint_text: str = attr.field(default="")
message_edit: str = attr.field(default="")
context: list[str] = attr.field(default=[]) # type: ignore
diff_edit: str = attr.field(default="")
test_patch: str = attr.field(default="")
test_command: str = attr.field(default="")
PASS_TO_PASS: str = attr.field(default="")
version: str = attr.field(default=__version__)

def save(self, fname: str | Path) -> None:
@@ -41,8 +46,65 @@ def load(cls, fname: str | Path) -> Sample:
kwargs["message_history"] = kwargs.get("message_history", [])[::-1]
kwargs["version"] = "0.2.0"
_version = kwargs["version"]
if _version < "0.3.0":
# Additional fields from SWE-Bench
kwargs["environment_setup_commit"] = ""
kwargs["hint_text"] = ""
kwargs["test_patch"] = ""
kwargs["PASS_TO_PASS"] = ""
kwargs["version"] = "0.3.0"
if _version != __version__:
raise SampleError(
f"Warning: sample version ({_version}) does not match current" f" version ({__version__})."
)
return cls(**kwargs)

@classmethod
def from_swe_bench(cls, benchmark: dict[str, str]) -> Sample:
"""Create a Sample from a SWE-Bench benchmark.
SWE-Bench Fields (https://huggingface.co/datasets/princeton-nlp/SWE-bench#dataset-structure)
- instance_id: (str) - A formatted instance identifier, usually as repo_owner__repo_name-PR-number.
- patch: (str) - The gold patch, the patch generated by the PR (minus test-related code), that resolved
the issue.
- repo: (str) - The repository owner/name identifier from GitHub.
- base_commit: (str) - The commit hash of the repository representing the HEAD of the repository before the
solution PR is applied.
- hints_text: (str) - Comments made on the issue prior to the creation of the solution PR’s first commit
creation date.
- created_at: (str) - The creation date of the pull request.
- test_patch: (str) - A test-file patch that was contributed by the solution PR.
- problem_statement: (str) - The issue title and body.
- version: (str) - Installation version to use for running evaluation.
- environment_setup_commit: (str) - commit hash to use for environment setup and installation.
- FAIL_TO_PASS: (str) - A json list of strings that represent the set of tests resolved by the PR and tied to
the issue resolution.
- PASS_TO_PASS: (str) - A json list of strings that represent tests that should pass before and after the PR
application.
"""
patch = benchmark.get("patch", "")
edited_files = re.findall(r"diff --git a/(.*?) b/\1", patch)
return cls(
title=f"SWE-bench-{benchmark['instance_id']}",
description="",
id=benchmark["instance_id"],
parent_id="",
repo=f"https://github.com/{benchmark.get('repo')}",
environment_setup_commit=benchmark.get("environment_setup_commit", ""),
merge_base=benchmark.get("base_commit"),
diff_merge_base="",
diff_active="",
message_history=[],
message_prompt=benchmark.get("problem_statement", ""),
hint_text=benchmark.get("hints_text", ""),
message_edit="",
context=edited_files,
diff_edit=patch,
test_patch=benchmark.get("test_patch", ""),
test_command=(
"" if not benchmark.get("FAIL_TO_PASS") else "pytest " + " ".join(json.loads(benchmark["FAIL_TO_PASS"]))
),
PASS_TO_PASS=(
"" if not benchmark.get("PASS_TO_PASS") else "pytest " + " ".join(json.loads(benchmark["PASS_TO_PASS"]))
),
)
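
FAIL_TO_PASS and PASS_TO_PASS arrive as JSON-encoded lists of test identifiers, which from_swe_bench joins into a single pytest command. A minimal sketch of that conversion, using made-up test names:

import json

fail_to_pass = '["tests/test_foo.py::test_bar", "tests/test_foo.py::test_baz"]'
test_command = "pytest " + " ".join(json.loads(fail_to_pass))
# -> "pytest tests/test_foo.py::test_bar tests/test_foo.py::test_baz"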
2 changes: 1 addition & 1 deletion tests/sampler_test.py
@@ -171,7 +171,7 @@ async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_l
assert "test_file.py" in edits[1]
assert "+# forty two" in edits[1]
assert sample.test_command == "test_test_command"
assert sample.version == "0.2.0"
assert sample.version == "0.3.0"


test_sample = {
