diff --git a/benchmarks/arg_parser.py b/benchmarks/arg_parser.py index c054842c3..93155a064 100644 --- a/benchmarks/arg_parser.py +++ b/benchmarks/arg_parser.py @@ -72,5 +72,11 @@ def common_benchmark_parser(): type=str, help="Fetch or load SWE-bench examples from split: dev (default), train or test.", ) + parser.add_argument( + "--auto_context_tokens", + default=0, + type=int, + help="Include auto-selected tokens in benchmark runs and evaluate precision/recall", + ) return parser diff --git a/benchmarks/benchmark_result.py b/benchmarks/benchmark_result.py index 04e4f3efb..cb0993b85 100644 --- a/benchmarks/benchmark_result.py +++ b/benchmarks/benchmark_result.py @@ -36,6 +36,11 @@ class BenchmarkResult: missing_functionality: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"}) extra_functionality: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"}) referenced_format: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"}) + test_eval_results: Optional[dict] = attr.ib(default=None, metadata={"display": "json"}) + test_eval_passed: Optional[bool] = attr.ib(default=None, metadata={"aggregation": "percent"}) + context_results: Optional[dict] = attr.ib(default=None, metadata={"display": "json"}) + context_precision: Optional[float] = attr.ib(default=None, metadata={"aggregation": "average"}) + context_recall: Optional[float] = attr.ib(default=None, metadata={"aggregation": "average"}) def display_color(self) -> str: if self.passed is None: diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 74b3a7cef..719636e10 100755 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -15,6 +15,7 @@ from benchmarks.arg_parser import common_benchmark_parser from benchmarks.benchmark_result import BenchmarkResult from benchmarks.benchmark_run import BenchmarkRun +from benchmarks.context_benchmark import run_auto_context_benchmark from benchmarks.run_sample import run_sample from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples from mentat.config import Config @@ -202,11 +203,12 @@ def from_module(cls, path_to_module: Path, module_name: str) -> Benchmark: return output @classmethod - def from_sample(cls, path_to_sample: Path) -> Benchmark: + def from_sample(cls, path_to_sample: Path, config: Config | None = None) -> Benchmark: sample = Sample.load(path_to_sample) return cls( title=sample.title, description=sample.description, + config=config or Config(), samples=[sample], ) @@ -223,10 +225,17 @@ async def run(self, retries: int = 1) -> list[BenchmarkResult]: family=formatted_title, ) try: - sample_result = await run_sample(sample) + if sample.context and self.config.auto_context_tokens: + score = await run_auto_context_benchmark(sample, self.config, include_context=False) + result.context_results = {**score, "auto_context_tokens": self.config.auto_context_tokens} + result.context_precision = score["precision"] + result.context_recall = score["recall"] + sample_result = await run_sample(sample, config=self.config) result.cost = sample_result["cost"] result.tokens = sample_result["tokens"] result.transcript = sample_result["transcript"] + result.test_eval_results = sample_result["test_eval_results"] + result.test_eval_passed = sample_result["test_eval_passed"] if self.verify is not None: result.verify = self.verify() @@ -251,7 +260,13 @@ def benchmark_listed(title, benchmarks): return False -def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 
1): +def run_benchmarks( + user_benchmarks: list[str], + directory: str, + retries: int = 1, + max_benchmarks: int | None = None, + auto_context_tokens: int = 0, +): # Load benchmarks dir_path = Path(directory).resolve() assert dir_path.exists(), f"Invalid directory: {directory}" @@ -263,7 +278,8 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1) if file.endswith(".py"): benchmark = Benchmark.from_module(path, "benchmark") elif file.endswith(".json"): - benchmark = Benchmark.from_sample(path) + config = Config(auto_context_tokens=auto_context_tokens) + benchmark = Benchmark.from_sample(path, config) else: continue @@ -277,7 +293,9 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1) results_cache = dir_path / f"benchmark_results_cache_{uuid4()}.jsonl" results_cache.touch() total_cost = 0.0 - for benchmark in benchmarks: + for i, benchmark in enumerate(benchmarks): + if max_benchmarks and i >= max_benchmarks: + break # Run benchmark.run() with timeout try: result = asyncio.run(benchmark.run(retries=retries)) @@ -328,4 +346,6 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1) args.benchmarks, args.directory, args.retries, + args.max_benchmarks, + args.auto_context_tokens, ) diff --git a/benchmarks/context_benchmark.py b/benchmarks/context_benchmark.py index 252eefbf0..599b012b7 100755 --- a/benchmarks/context_benchmark.py +++ b/benchmarks/context_benchmark.py @@ -1,193 +1,106 @@ -#!/usr/bin/env python import asyncio import json import os -from collections import defaultdict -from itertools import islice +from datetime import datetime from pathlib import Path from typing import Any -from git import Repo - from benchmarks.arg_parser import common_benchmark_parser -from mentat.code_context import CodeContext -from mentat.code_feature import CodeFeature -from mentat.code_file_manager import CodeFileManager +from benchmarks.run_sample import setup_sample +from benchmarks.swe_bench_runner import get_swe_samples, SWE_BENCH_SAMPLES_DIR +from mentat import Mentat from mentat.config import Config -from mentat.cost_tracker import CostTracker -from mentat.llm_api_handler import count_tokens, model_context_size -from mentat.sampler.utils import clone_repo -from mentat.session_context import SESSION_CONTEXT, SessionContext - - -class MockStream: - def send(self, message, **kwargs): - end = kwargs.get("end", "\n") - print(message, end=end) - - -def _load_benchmarks() -> dict[str, dict[str, Any]]: - """Load all benchmarks found in benchmark_repos""" - benchmarks = {} - benchmarks_dir = Path(__file__).parent / "../benchmark_repos" - for repo_dir in benchmarks_dir.iterdir(): - benchmarks_path = repo_dir / "benchmarks.json" - if benchmarks_path.exists(): - with open(benchmarks_path, "r") as f: - benchmarks.update(json.load(f)) - return benchmarks - - -def _convert_features_to_line_sets(git_root: Path, features: list[CodeFeature]) -> defaultdict[set]: - """Convert a list of features to a dict of {path: set(lines)} for comparison""" - lines = defaultdict(set) - for feature in features: - # Non-explicit features (e.g. CodeMaps) are considered false positives. - # Using negative numbers here as that affect. 
-
-        path = feature.path.relative_to(git_root)
-        interval = feature.interval
-        lines[path].update(range(interval.start, interval.end + 1))
-    return lines
-
-
-def evaluate(
-    git_root: Path,
-    actual: list[CodeFeature],
-    expected: list[CodeFeature],
-) -> dict[str, float]:
-    """Compare two lists of features and return precision, recall and f1 scores"""
-    actual_lines = _convert_features_to_line_sets(git_root, actual)
-    expected_lines = _convert_features_to_line_sets(git_root, expected)
-
-    _TP, _FP, _FN = 0, 0, 0
-    for file in actual_lines | expected_lines:
-        actual_set = actual_lines[file]
-        expected_set = expected_lines[file]
-        _TP += len(actual_set & expected_set)
-        _FP += len(actual_set - expected_set)
-        _FN += len(expected_set - actual_set)
+from mentat.sampler.sample import Sample
+from mentat.session_context import SESSION_CONTEXT
 
-    precision, recall, f1 = None, None, None
-    if (_TP + _FP) > 0:
-        precision = _TP / (_TP + _FP)
-    if (_TP + _FN) > 0:
-        recall = _TP / (_TP + _FN)
-    if precision and recall:
-        f1 = 2 * precision * recall / (precision + recall)
-    return {"precision": precision, "recall": recall, "f1": f1}
+def _score(predicted: set[Path], expected: set[Path]) -> dict[str, Any]:
+    true_positives = predicted.intersection(expected)
+    false_positives = predicted.difference(expected)
+    false_negatives = expected.difference(predicted)
+    precision = len(true_positives) / (len(true_positives) + len(false_positives))
+    recall = len(true_positives) / (len(true_positives) + len(false_negatives))
+    return {"precision": precision, "recall": recall, "n_true": len(expected)}
 
-async def select_features_for_benchmark(session_context, benchmark, eval=True, use_expected=False, use_llm=True):
-    """Select features for benchmark using expected edits as a guide"""
-    git_root = session_context.git_root
-    config = session_context.config
-    parser = config.parser
-    code_context = session_context.code_context
+async def run_auto_context_benchmark(
+    sample: Sample, config: Config, cwd: Path | str | None = None, include_context: bool = False
+) -> dict[str, Any]:
+    """Run auto-context selection on a sample and score it against the sample's ground-truth context"""
+    starting_dir = Path.cwd()
 
-    # The longest context that could have been included to generate expected_edits
-    model = config.model
-    mentat_prompt_tokens = count_tokens(parser.get_system_prompt(), model)
-    expected_edits, expected_edits_tokens = None, 0
-    if use_expected:
-        expected_edits = benchmark["expected_edits"]
-        expected_edits_tokens = count_tokens(expected_edits, model)
-    max_context_tokens = model_context_size(model) - mentat_prompt_tokens - expected_edits_tokens
-    # Fill-in available context
-    config.auto_context_tokens = 8000
-    code_context.use_llm = use_llm
-    await code_context.get_code_message(benchmark["prompt"], max_context_tokens, expected_edits)
-    git_root_length = len(str(git_root)) + 1
-    selected_features = [f.ref()[git_root_length:] for f in code_context.features]
-
-    selector_performance = {}
-    if eval:
-        edited_features = [CodeFeature(git_root / f) for f in benchmark["edited_features"]]
-        selector_performance = evaluate(git_root, code_context.features, edited_features)
-    return {"features": selected_features, "score": selector_performance}
-
-
-async def test_code_context_performance(benchmarks, max_benchmarks=10):
-    """Run a set of benchmarks and evaluate performance
-
-    Run standalone:
-        `./benchmarks/context_benchmark.py`
-    """
-    # Load applicable benchmarks
-    all_benchmarks = _load_benchmarks()
-    if len(benchmarks) > 0:
-        benchmarks_to_run = {k: v for 
k, v in all_benchmarks.items() if k in benchmarks} - else: - benchmarks_to_run = dict(islice(all_benchmarks.items(), max_benchmarks)) - - # Run each one - scores = {} - for benchmark in benchmarks_to_run.values(): - print("\n" + benchmark["prompt"]) - - # Setup the cwd the same way as in generate - url = benchmark["codebase_url"] - codebase = clone_repo(url=url, local_dir_name=url.split("/")[-1], refresh=False) - os.chdir(codebase) - repo = Repo(".") - repo.git.checkout(benchmark["commit"]) - - # Initialize a full SESSION_CONTEXT - stream = MockStream() - config = Config() - code_context = CodeContext(stream, os.getcwd()) - session_context = SessionContext( - stream, - CostTracker(), - Path.cwd(), - config, - code_context, - CodeFileManager(), - None, + if not config.auto_context_tokens or not sample.context: + raise ValueError( + "In order to run the auto-context benchmark, sample.context must not " + "be empty (ground truth) and config.auto_context_tokens must be > 0." ) - SESSION_CONTEXT.set(session_context) - - # Run the benchmark and print results - scores = [] - for use_llm in [False, True]: - for use_expected in [False, True]: - try: - if not use_llm and use_expected: - continue # Not relevant - results = await select_features_for_benchmark( - session_context, - benchmark, - eval=True, - use_expected=use_expected, - use_llm=use_llm, - ) - score = { - **results["score"], - "selected_features": results["features"], - "edited_features": benchmark["edited_features"], - "use_llm": use_llm, - "use_expected": use_expected, - } - scores.append(score) - print( - f" UseExpected={use_expected}\t" - f"| LLM={use_llm}\t" - f"| Recall={(score['recall'] or 0.):.3f}\t" - f"| Precision={(score['precision'] or 0.):.3f}" - ) - except Exception as e: - print(f"Error: '{e}'; skipping") - - return scores + paths = [] if not include_context else [Path(a) for a in sample.context] + + try: + _, cwd, _, _ = setup_sample(sample, None, skip_test_exec=True) + exclude_paths = [cwd / ".venv"] + mentat = Mentat(cwd=cwd, paths=paths, exclude_paths=exclude_paths, config=config or Config()) + await mentat.startup() + await asyncio.sleep(0.01) # Required to initialize llm_api_handler for embeddings + + # TODO: If there's a conversation history, we might consider the cumulative context. + # Setup a mock for the LLM response and run the conversation until this point. 
+ code_context = SESSION_CONTEXT.get().code_context + _ = await code_context.get_code_message(0, sample.message_prompt) + predicted = set(path.relative_to(cwd) for path in code_context.include_files.keys()) + actual = {Path(a) for a in sample.context} + score = _score(predicted, actual) + + await mentat.shutdown() + return score + finally: + os.chdir(starting_dir) + + +def main(user_samples: list[str], directory: str): + # Load benchmarks + dir_path = Path(directory).resolve() + assert dir_path.exists(), f"Invalid directory: {directory}" + print(f"Running benchmarks from {dir_path}") + samples: list[Sample] = [] + for root, dirs, files in os.walk(dir_path): + for file in files: + path = Path(root) / file + if file.endswith(".json"): + sample = Sample.load(path) + else: + continue + if user_samples and not any(s in sample.title for s in user_samples): + continue + samples.append(sample) + print("Found Samples:\n" + "\n".join(s.title for s in samples)) + print("*" * 80) + + config = Config(auto_context_tokens=8000) + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + results_path = dir_path / f"context_benchmark_results_{timestamp}.jsonl" + for sample in samples: + print(f"Running benchmark for {sample.title}") + accuracy = asyncio.run(run_auto_context_benchmark(sample, config, cwd=dir_path)) + print(f"Results: {accuracy}") + print("*" * 80) + with open(results_path, "a") as f: + f.write(json.dumps({sample.id: accuracy}) + "\n") if __name__ == "__main__": parser = common_benchmark_parser() args = parser.parse_args() - asyncio.run( - test_code_context_performance( - args.benchmarks, - args.max_benchmarks, - ) + if args.swe_bench: + if args.swe_bench not in {"dev", "train", "test"}: + print("Invalid SWE-Bench split.") + exit(1) + # Download and save SWE benchmarks as Samples + samples = get_swe_samples(args.swe_bench, args.max_benchmarks) + sample_titles = [sample.title for sample in samples] + args.benchmarks = sample_titles + args.directory = SWE_BENCH_SAMPLES_DIR / args.swe_bench + main( + args.benchmarks, + args.directory, ) diff --git a/benchmarks/run_sample.py b/benchmarks/run_sample.py index deafa0e5d..7634745ab 100644 --- a/benchmarks/run_sample.py +++ b/benchmarks/run_sample.py @@ -1,38 +1,71 @@ -from pathlib import Path +import json +import random import subprocess -import re +import sys +from pathlib import Path from typing import Any +from git import Repo +import tqdm + from mentat import Mentat +from mentat.config import Config from mentat.errors import SampleError from mentat.git_handler import get_git_diff from mentat.parsers.git_parser import GitParser from mentat.sampler.sample import Sample from mentat.sampler.utils import get_active_snapshot_commit, setup_repo, apply_diff_to_repo from mentat.session_context import SESSION_CONTEXT -from mentat.utils import convert_string_to_asynciter -async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str, Any]: - """Run a sample using Mentat and return the resulting diff""" - +def setup_sample( + sample: Sample, cwd: Path | str | None, skip_test_exec: bool = False +) -> tuple[Repo, Path, str, str | None]: + setup_commit = sample.environment_setup_commit or sample.merge_base repo = setup_repo( url=sample.repo, cwd=cwd, - commit=sample.merge_base, + commit=setup_commit, diff_merge_base=sample.diff_merge_base, diff_active=sample.diff_active, ) cwd = Path(repo.working_dir) - # Make a commit from the pre-edited state (should match diff_active) - commit_active = get_active_snapshot_commit(repo) + test_executable = None 
+ if not skip_test_exec and (sample.FAIL_TO_PASS or sample.PASS_TO_PASS): + # If there's an environment_setup_commit, this is what it's needed for. + try: + test_executable = get_test_executable(Path(repo.working_dir)) + except SampleError as e: + print(f"Error setting up virtual environment: {e}") + print("Using default python executable instead.") + if not test_executable: + test_executable = sys.executable + + if sample.environment_setup_commit and sample.merge_base: + # SWE-Bench samples have an environmental_setup_commit (above), + # then a merge_base to checkout. + repo.git.reset("--hard") + repo.git.checkout(sample.merge_base) + commit_active = sample.merge_base + else: + # Mentat Samples have an active diff which was set in setup_repo, + # so here create a snapshot commit (to generate diff_edit against later) + commit_active = get_active_snapshot_commit(repo) + + return repo, cwd, test_executable, commit_active + + +async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Config | None = None) -> dict[str, Any]: + """Run a sample using Mentat and return the resulting diff""" + + repo, cwd, test_executable, commit_active = setup_sample(sample, cwd) # Run sample in PythonClient paths = list[Path]() for a in sample.context: paths.append(Path(a)) - mentat = Mentat(cwd=cwd, paths=paths) + mentat = Mentat(cwd=cwd, paths=paths, config=config or Config()) await mentat.startup() session_context = SESSION_CONTEXT.get() conversation = session_context.conversation @@ -41,8 +74,7 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str, if msg["role"] == "user": conversation.add_user_message(msg["content"]) elif msg["role"] == "assistant": - generator = convert_string_to_asynciter(msg["content"], 100) - parsed_llm_response = await GitParser().stream_and_parse_llm_response(generator) + parsed_llm_response = GitParser().parse_llm_response(msg["content"]) content = session_context.config.parser.file_edits_to_llm_message(parsed_llm_response) conversation.add_model_message(content, [], parsed_llm_response) else: @@ -59,26 +91,28 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str, message_eval = str(transcript_messages[-1].get("message", "")) diff_eval = get_git_diff(commit_active or "HEAD", cwd=cwd) - test_results = {"passed": 0, "failed": 0, "error": ""} - if sample.test_command: - if sample.test_patch: - apply_diff_to_repo(sample.test_patch, repo) - try: - output = subprocess.run( - sample.test_command, - shell=True, - capture_output=True, - text=True, - cwd=cwd, - ) - matches = re.search(r"(?:(\d+) passed)?(?:, )?(?:(\d+) failed)?", output.stdout) - if matches: - test_results["passed"] = int(matches.group(1)) or 0 - test_results["failed"] = int(matches.group(2)) or 0 - else: - raise SampleError(f"Test command failed: {output.stdout}") - except Exception as e: - test_results["error"] = str(e) + test_results = None + test_passed = None + if sample.test_patch: + apply_diff_to_repo(sample.test_patch, repo) + if sample.FAIL_TO_PASS: + tests = json.loads(sample.FAIL_TO_PASS) + total = len(tests) + passed = 0 + errors = list[dict[str, str]]() + for test in tests: + _passed, _error = get_test_result(test, cwd, test_executable) + if _passed: + passed += 1 + if _error: + errors.append({"test": test, "error": _error}) + test_results = { + "passed": passed, + "total": total, + "passed_percent": passed / total * 100, + # "errors": errors, # Too big, but useful for debugging + } + test_passed = passed == total return { "id": 
sample.id, @@ -90,5 +124,213 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str, "id": sample.id, "messages": transcript_messages, }, - "test_results": test_results, + "test_eval_results": test_results, + "test_eval_passed": test_passed, } + + +test_requirements_for_repo = { + "pvlib-python": [ + "setuptools", + "pytest", + "pytest-cov", + "pytest-mock", + "requests-mock", + "pytest-timeout", + "pytest-rerunfailures", + "pytest-remotedata", + ], + "pydicom": ["setuptools", "pytest"], + "sqlfluff": [ + "setuptools", + "pytest", + "pytest-cov", + "pytest-mock", + "Jinja2", + "oyaml", + ], + "pyvista": [ + "setuptools", + "pytest", + "ipython", + "ipywidgets", + "ipykernel", + "tqdm", + ], + "astroid": [ + "setuptools", + "pytest", + "attrs", + "types-attrs", + "nose", + "numpy", + "python-dateutil", + "types-python-dateutil", + "six", + "types-six", + ], + "marshmallow": [ + "setuptools", + "pytest", + "pytz", + "simplejson", + ], +} + + +def get_test_executable(cwd: Path) -> str: + """Rebuild every time with the latest setup.""" + + venv_dir = cwd / ".venv" + repo_name = cwd.name + + try: + python_executable = "python3" if sys.platform != "win32" else "python" + subprocess.run([python_executable, "-m", "venv", str(venv_dir)], check=True, cwd=cwd, capture_output=True) + except Exception as e: + raise SampleError(f"Error creating virtual environment: {e}") + + # Install as a pip module + try: + output = subprocess.run( + [venv_dir / "bin" / "pip", "install", "-e", "."], check=True, cwd=cwd, capture_output=True + ) + if output.returncode != 0: + raise SampleError(f"Error installing sample as a pip module: {output.stderr}") + except Exception as e: + raise SampleError(f"Error installing sample as a pip module: {e}") + + # Requirements are hard-coded by repo + if repo_name not in test_requirements_for_repo: + raise SampleError(f"No requirements found for repo '{repo_name}'") + requirements = test_requirements_for_repo[repo_name] + + # Install them all with pip + try: + output = subprocess.run( + [venv_dir / "bin" / "pip", "install", *list(requirements)], check=True, cwd=cwd, capture_output=True + ) + if output.returncode != 0: + raise SampleError(f"Error installing requirements: {output.stderr}") + except Exception as e: + raise SampleError(f"Error installing requirements: {e}") + + return str(venv_dir / "bin" / "python") + + +def get_test_result(test: str, cwd: Path, test_executable: str) -> tuple[bool, str]: + passed, error = False, "" + command = [test_executable, "-m", "pytest"] + if "[" in test: + # Some tests include parameters, like "..is_valid[3.1415]". + # Running '-k' over the whole suite is very slow. 
+        path, params = test.split("[", 1)
+        params = params[:-1]  # Remove trailing ']'
+        command += [path, "-k", params]
+    else:
+        command += [test]
+    try:
+        output = subprocess.run(
+            command,
+            capture_output=True,
+            text=True,
+            cwd=cwd,
+        )
+        if (output.returncode != 0 and output.stderr) or not output.stdout:
+            raise SampleError(f"Test command failed: {output.stderr}")
+
+        # Starting from the end, find the first line that mentions "passed", "failed" or "skipped"
+        lines = output.stdout.splitlines()
+        result_line = next(
+            line for line in reversed(lines) if any(key in line for key in {"passed", "failed", "skipped"})
+        )
+        _passed = "passed" in result_line or "skipped" in result_line
+        _failed = "failed" in result_line
+        if _passed == _failed:
+            raise SampleError(f"Could not determine test result from line: {result_line}")
+        passed = _passed
+        if _failed:
+            raise SampleError("Test failed:\n" + "\n".join(lines))
+    except (SampleError, StopIteration, Exception) as e:
+        error = str(e)
+    return passed, error
+
+
+def validate_test_fields(sample: Sample) -> dict[str, Any]:
+    test_results: dict[str, Any] = {
+        "PASS_TO_PASS": {"passed": 0, "total": 0, "errors": []},
+        "FAIL_TO_PASS_PRE": {"passed": 0, "total": 0, "errors": []},
+        "FAIL_TO_PASS_POST": {"passed": 0, "total": 0, "errors": []},
+    }
+
+    if not sample.FAIL_TO_PASS and not sample.PASS_TO_PASS:
+        return test_results
+
+    repo, cwd, test_executable, _ = setup_sample(sample, None)
+
+    # Run the PASS_TO_PASS tests, expected to PASS
+    if sample.PASS_TO_PASS:
+        tests = json.loads(sample.PASS_TO_PASS)
+
+        # There are sometimes hundreds of tests which take ~30 minutes to all complete.
+        # Since we'll check all the FAIL_TO_PASS tests, here we just want to confirm the
+        # environment is set up correctly, so we sample 10 tests.
+ tests = random.sample(tests, min(10, len(tests))) + + # Iterate with tqdm + for test in tqdm.tqdm(tests, desc="PASS_TO_PASS tests", unit="test"): + test_results["PASS_TO_PASS"]["total"] += 1 + passed, error = False, "" + try: + passed, error = get_test_result(test, cwd, test_executable) + if passed: + test_results["PASS_TO_PASS"]["passed"] += 1 + elif error: + raise SampleError(error) + except SampleError as e: + test_results["PASS_TO_PASS"]["errors"].append({"test": test, "error": str(e)}) + print("PASS_TO_PASS results: ", test_results["PASS_TO_PASS"]) + + # Apply test patch + if sample.test_patch: + print("Applying test patch...") + apply_diff_to_repo(sample.test_patch, repo) + + # Run FAIL_TO_PASS tests expected to FAIL + if sample.FAIL_TO_PASS: + tests = json.loads(sample.FAIL_TO_PASS) + for test in tqdm.tqdm(tests, desc="FAIL_TO_PASS tests", unit="test"): + test_results["FAIL_TO_PASS_PRE"]["total"] += 1 + passed, error = False, "" + try: + passed, error = get_test_result(test, cwd, test_executable) + if passed: + test_results["FAIL_TO_PASS_PRE"]["passed"] += 1 + elif error: + raise SampleError(error) + except SampleError as e: + test_results["FAIL_TO_PASS_PRE"]["errors"].append({"test": test, "error": str(e)}) + print("FAIL_TO_PASS_PRE results: ", test_results["FAIL_TO_PASS_PRE"]) + + # Apply golden patch + if sample.diff_edit: + print("Applying diff_edit...") + apply_diff_to_repo(sample.diff_edit, repo) + + # Run FAIL_TO_PASS tests expected to PASS + if sample.FAIL_TO_PASS: + tests = json.loads(sample.FAIL_TO_PASS) + for test in tqdm.tqdm(tests, desc="FAIL_TO_PASS tests", unit="test"): + test_results["FAIL_TO_PASS_POST"]["total"] += 1 + passed, error = False, "" + try: + passed, error = get_test_result(test, cwd, test_executable) + if passed: + test_results["FAIL_TO_PASS_POST"]["passed"] += 1 + elif error: + raise SampleError(error) + except SampleError as e: + test_results["FAIL_TO_PASS_POST"]["errors"].append({"test": test, "error": str(e)}) + print("FAIL_TO_PASS_POST results: ", test_results["FAIL_TO_PASS_POST"]) + + return test_results diff --git a/benchmarks/swe_bench_runner.py b/benchmarks/swe_bench_runner.py index b707f7c83..94360d084 100644 --- a/benchmarks/swe_bench_runner.py +++ b/benchmarks/swe_bench_runner.py @@ -1,15 +1,35 @@ +""" +NOTE: Not all of the Samples are valid with our current implementation. The list of +valid samples is saved in the `summoning-the-shoggoth` repo for now. Running this file +directly from the command line will run the full validation script and overwrite the +results there (takes a few hours). +""" +import argparse +import json +import os from pathlib import Path +from typing import Any from datasets import load_dataset, DatasetDict # type: ignore from mentat.sampler.sample import Sample +from benchmarks.run_sample import validate_test_fields SWE_BENCH_SAMPLES_DIR = Path(__file__).parent / "benchmarks" / "swe_bench_samples" +SWE_VALIDATION_RESULTS_PATH = ( + Path(__file__).parent.parent.parent + / "summoning-the-shoggoth" + / "swe_bench" + / "swe_bench_validation_results_2024-03-29.json" +) def download_swe_benchmarks(split: str = "dev") -> list[dict[str, str]]: - """3 splits are available: dev (225), test (2.29k), and train (19k).""" + """Get raw SWE-Bench json samples from huggingface + + 3 splits are available: dev (225), test (2.29k), and train (19k). 
+    """
     dataset: DatasetDict = load_dataset("princeton-nlp/SWE-bench", split=split)  # type: ignore
     dataset: list[dict[str, str]] = [dict(benchmark) for benchmark in dataset]
     return dataset
@@ -32,7 +52,79 @@ def get_swe_samples(split: str = "dev", max_benchmarks: int | None = None) -> li
     else:
         samples = [Sample.load(fname) for fname in saved_benchmarks]
 
+    # Check that samples are valid
+    valid_samples = list[Sample]()
+    if not SWE_VALIDATION_RESULTS_PATH.exists():
+        print(f"Sample validation results not found at {SWE_VALIDATION_RESULTS_PATH}.")
+        print("Validating SWE samples...")
+        print("\033[93m" + "Warning: This will take a couple hours." + "\033[0m")
+        # This takes a couple hours.
+        validate_swe_samples()
+    with open(SWE_VALIDATION_RESULTS_PATH, "r") as f:
+        swe_validation_results = json.load(f)
+    for sample in samples:
+        results = swe_validation_results.get(sample.title)
+        pass_to_pass = (
+            "PASS_TO_PASS" in results and results["PASS_TO_PASS"]["passed"] == results["PASS_TO_PASS"]["total"]
+        )
+        fail_to_pass_post = (
+            "FAIL_TO_PASS_POST" in results
+            and results["FAIL_TO_PASS_POST"]["passed"] == results["FAIL_TO_PASS_POST"]["total"]
+        )
+        if pass_to_pass and fail_to_pass_post:
+            valid_samples.append(sample)
+    samples = valid_samples
+
     if max_benchmarks:
         samples = samples[:max_benchmarks]
     print(f"Selected {len(samples)} benchmarks from '{split}'")
     return samples
+
+
+def validate_swe_samples(targets: list[str] | None = None, refresh: bool = True) -> None:
+    """Set up each sample and run its validation tests."""
+    cwd = Path.cwd()
+    samples = get_swe_samples()
+
+    if SWE_VALIDATION_RESULTS_PATH.exists():
+        with open(SWE_VALIDATION_RESULTS_PATH, "r") as f:
+            results = json.load(f)
+        print(f"Loaded {len(results)} previous validation results.")
+    else:
+        SWE_VALIDATION_RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
+        SWE_VALIDATION_RESULTS_PATH.touch()
+        results = dict[str, Any]()
+    for sample in samples:
+        os.chdir(cwd)
+
+        if targets and not any(b in sample.title for b in targets):
+            continue
+        if not refresh and sample.title in results:
+            continue
+        try:
+            print(80 * "*" + f"\nValidating {sample.id}...")
+            test_results = validate_test_fields(sample)
+            percentages = dict[str, float]()
+            for category, result in test_results.items():
+                _passed, _total = result.get("passed", 0), result.get("total", 0)
+                percentage = 0 if _total == 0 else _passed / _total * 100
+                percentages[category] = percentage
+            for category, percent in percentages.items():
+                expected = 0 if "_PRE" in category else 100
+                print(f"{category}: {percent:.2f}% (expected {expected}%)")
+            results[sample.title] = test_results
+        except Exception as e:
+            print(f"Error: {e}")
+            results[sample.title] = {"error": str(e)}
+        finally:
+            with open(SWE_VALIDATION_RESULTS_PATH, "w") as f:
+                json.dump(results, f, indent=4)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("targets", nargs="*")
+    parser.add_argument("--refresh", "-r", action="store_true")
+    parsed_args = parser.parse_args()
+
+    validate_swe_samples(targets=[str(arg) for arg in parsed_args.targets], refresh=parsed_args.refresh)
diff --git a/mentat/code_feature.py b/mentat/code_feature.py
index 17fb71796..e41a78292 100644
--- a/mentat/code_feature.py
+++ b/mentat/code_feature.py
@@ -174,12 +174,12 @@ async def count_feature_tokens(features: list[CodeFeature], model: str) -> list[
     """Return the number of tokens in each feature."""
     sem = asyncio.Semaphore(10)
 
-    async def _count_tokens(feature: CodeFeature) -> int:
+    feature_tokens = 
list[int]() + for feature in features: async with sem: - return feature.count_tokens(model) - - tasks = [_count_tokens(f) for f in features] - return await asyncio.gather(*tasks) + tokens = feature.count_tokens(model) + feature_tokens.append(tokens) + return feature_tokens def _get_code_message_from_intervals(features: list[CodeFeature]) -> list[str]: diff --git a/mentat/embeddings.py b/mentat/embeddings.py index 8242315a3..801e9fad9 100644 --- a/mentat/embeddings.py +++ b/mentat/embeddings.py @@ -10,7 +10,7 @@ from mentat.session_input import ask_yes_no from mentat.utils import mentat_dir_path -EMBEDDINGS_API_BATCH_SIZE = 2048 +EMBEDDINGS_API_BATCH_SIZE = 1536 client = chromadb.PersistentClient(path=str(mentat_dir_path / "chroma")) @@ -64,7 +64,7 @@ def query(self, prompt: str, checksums: list[str]) -> dict[str, float]: results = self._collection.query( # type: ignore query_texts=[prompt], where={"active": True}, - n_results=len(checksums) + 1, + n_results=len(checksums), ) self._collection.update( # type: ignore ids=checksums, @@ -99,6 +99,7 @@ async def get_feature_similarity_scores( # Identify which items need embeddings. checksums: list[str] = [f.get_checksum() for f in features] + ignored_checksums = set[str]() tokens: list[int] = await count_feature_tokens(features, embedding_model) embed_texts = list[str]() embed_checksums = list[str]() @@ -110,6 +111,7 @@ async def get_feature_similarity_scores( f" maximum of {max_model_tokens} for model {config.embedding_model}." " Skipping." ) + ignored_checksums.add(checksum) continue if not collection.exists(checksum) and checksum not in embed_checksums: embed_texts.append("\n".join(feature.get_code_message())) @@ -145,6 +147,7 @@ async def get_feature_similarity_scores( # Get similarity scores stream.send(None, channel="loading", terminate=True) - _checksums = list(set(checksums)) + _checksums = list(c for c in set(checksums) if c not in ignored_checksums) scores = collection.query(prompt, _checksums) + return [scores.get(f.get_checksum(), 0) for f in features] diff --git a/mentat/parsers/parser.py b/mentat/parsers/parser.py index e911a6218..eebe02d7a 100644 --- a/mentat/parsers/parser.py +++ b/mentat/parsers/parser.py @@ -383,3 +383,6 @@ async def parse_llm_response(self, response: str) -> ParsedLLMResponse: parsed_response = await self.stream_and_parse_llm_response(async_iter_response) self._silence_printer = False return parsed_response + + def file_edits_to_llm_message(self, parsedLLMResponse: ParsedLLMResponse) -> str: + raise NotImplementedError() diff --git a/mentat/sampler/CHANGELOG.md b/mentat/sampler/CHANGELOG.md index 1df0ee131..6678ecacb 100644 --- a/mentat/sampler/CHANGELOG.md +++ b/mentat/sampler/CHANGELOG.md @@ -3,7 +3,9 @@ All notable changes to this project will be documented in this file. ## [2024-03-22]: Add fields to cover content of SWE-Bench -Fields in common are left with their current name. Missing fields (*) are added. +Fields in common are left with their current name and missing fields (*) are added +with one exception: test_command is changed to "FAIL_TO_PASS", which is just a +json list of test commands. 
Sampler SWE-Bench - title diff --git a/mentat/sampler/README.md b/mentat/sampler/README.md index c3562adef..f8f912312 100644 --- a/mentat/sampler/README.md +++ b/mentat/sampler/README.md @@ -30,9 +30,8 @@ A `Sample` captures interactions between a developer and any LLM Coding Assistan | message_edit | | `str` | plaintext response returned for sample edit | | diff_edit | * | `str` | between starting (diff_head) and ending code. | | test_patch | | `str` | A patch to files used to evaluate the samples -| test_command | | `str` | discrete pass/fail, e.g. ‘pytest -k diff_active’ | -| PASS_TO_PASS | | `str` | discrete pass/fail, expected to pass - +| FAIL_TO_PASS | | `str` | A json list of test commands resolved by diff_edit | +| PASS_TO_PASS | | `str` | A json list of test commands that pass before and after | | version | | `str` | current Sample API version | Notes: diff --git a/mentat/sampler/sample.py b/mentat/sampler/sample.py index f74a617fb..5a21d06f0 100644 --- a/mentat/sampler/sample.py +++ b/mentat/sampler/sample.py @@ -29,7 +29,7 @@ class Sample: context: list[str] = attr.field(default=[]) # type: ignore diff_edit: str = attr.field(default="") test_patch: str = attr.field(default="") - test_command: str = attr.field(default="") + FAIL_TO_PASS: str = attr.field(default="") PASS_TO_PASS: str = attr.field(default="") version: str = attr.field(default=__version__) @@ -51,8 +51,14 @@ def load(cls, fname: str | Path) -> Sample: kwargs["environment_setup_commit"] = "" kwargs["hint_text"] = "" kwargs["test_patch"] = "" + if "test_command" in kwargs: + kwargs["FAIL_TO_PASS"] = json.dumps([kwargs["test_command"]]) + del kwargs["test_command"] + else: + kwargs["FAIL_TO_PASS"] = "" kwargs["PASS_TO_PASS"] = "" kwargs["version"] = "0.3.0" + _version = kwargs["version"] if _version != __version__: raise SampleError( f"Warning: sample version ({_version}) does not match current" f" version ({__version__})." 
@@ -101,10 +107,6 @@ def from_swe_bench(cls, benchmark: dict[str, str]) -> Sample: context=edited_files, diff_edit=patch, test_patch=benchmark.get("test_patch", ""), - test_command=( - "" if not benchmark.get("FAIL_TO_PASS") else "pytest " + " ".join(json.loads(benchmark["FAIL_TO_PASS"])) - ), - PASS_TO_PASS=( - "" if not benchmark.get("PASS_TO_PASS") else "pytest " + " ".join(json.loads(benchmark["PASS_TO_PASS"])) - ), + FAIL_TO_PASS=benchmark.get("FAIL_TO_PASS", ""), + PASS_TO_PASS=benchmark.get("PASS_TO_PASS", ""), ) diff --git a/mentat/sampler/sampler.py b/mentat/sampler/sampler.py index afcaa8cc1..acf381ef6 100644 --- a/mentat/sampler/sampler.py +++ b/mentat/sampler/sampler.py @@ -1,3 +1,4 @@ +import json import subprocess from pathlib import Path from uuid import uuid4 @@ -200,7 +201,7 @@ def _rp(f: str | Path) -> str: message_edit=message_edit, context=list(context), diff_edit=diff_edit, - test_command=test_command, + FAIL_TO_PASS=json.dumps([test_command]), ) # Save the hexsha and id diff --git a/tests/sampler_test.py b/tests/sampler_test.py index 89e390b41..f9f84a3c9 100644 --- a/tests/sampler_test.py +++ b/tests/sampler_test.py @@ -1,3 +1,4 @@ +import json import re from pathlib import Path from textwrap import dedent @@ -93,7 +94,7 @@ async def test_sample_from_context( assert sample.context == ["multifile_calculator/operations.py"] assert sample.diff_edit == "" assert sample.id != "" - assert sample.test_command == "test_test_command" + assert sample.FAIL_TO_PASS == json.dumps(["test_test_command"]) assert sample.version == __version__ @@ -170,7 +171,7 @@ async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_l assert "+# forty two" in edits[0] assert "test_file.py" in edits[1] assert "+# forty two" in edits[1] - assert sample.test_command == "test_test_command" + assert sample.FAIL_TO_PASS == json.dumps(["test_test_command"]) assert sample.version == "0.3.0" @@ -196,8 +197,8 @@ async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_l ' hashlib.sha1(data.encode("utf-8")).hexdigest()\n+\n+\n async def' " run_subprocess_async(*args: str) -> str:\n" ), - "test_command": "", - "version": "0.1.0", + "FAIL_TO_PASS": "", + "version": "0.3.0", } @@ -406,7 +407,7 @@ async def test_sampler_integration(temp_testbed, mock_session_context, mock_call await client.call_mentat("test_url") await client.call_mentat("test_title") await client.call_mentat("test_description") - await client.call_mentat("test_test_command") + await client.call_mentat("") await client.call_mentat("q") await client.shutdown()