Skip to content

Restrict git #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/system.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Install the project
run: uv sync
- name: Clone
run: uv run commit0 clone simpy
- name: Setup
- name: Set up commit0
run: uv run commit0 setup simpy
- name: Build docker images
run: uv run commit0 build simpy
- name: Get tests
run: uv run commit0 get-tests simpy
Expand Down
4 changes: 2 additions & 2 deletions commit0/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main() -> None:
# after hydra gets all configs, put command-line arguments back
sys.argv = sys_argv
# repo_split: split from command line has a higher priority than split in hydra
if command in ["clone", "build", "evaluate", "evaluate-reference", "save"]:
if command in ["setup", "build", "evaluate", "evaluate-reference", "save"]:
if len(sys.argv) >= 3:
if sys.argv[2] not in SPLIT:
raise ValueError(
Expand All @@ -39,7 +39,7 @@ def main() -> None:
config.repo_split = sys.argv[2]
config.base_dir = os.path.abspath(config.base_dir)

if command == "clone":
if command == "setup":
commit0.harness.setup.main(
config.dataset_name,
config.dataset_split,
Expand Down
2 changes: 1 addition & 1 deletion commit0/harness/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class RepoInstance(TypedDict):

# available commands
COMMANDS = [
"clone",
"setup",
"build",
"test",
"test-reference",
Expand Down
16 changes: 10 additions & 6 deletions commit0/harness/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import time
import traceback
import pwd
from pathlib import Path
from io import BytesIO
from typing import Optional, List, Union
Expand Down Expand Up @@ -158,23 +159,26 @@ def copy_ssh_pubkey_from_container(container: Container) -> None:
if exit_code != 0:
raise Exception(f"Error reading file: {output.decode('utf-8').strip()}")
public_key = output.decode("utf-8").strip()
public_key = f"no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty {public_key}"

local_authorized_keys_path = os.path.expanduser("~/.ssh/authorized_keys")
os.makedirs(os.path.dirname(local_authorized_keys_path), exist_ok=True)
if not os.path.exists(local_authorized_keys_path):
user_info = pwd.getpwnam("git")
home_directory = user_info.pw_dir
authorized_keys_path = os.path.join(home_directory, ".ssh", "authorized_keys")
os.makedirs(os.path.dirname(authorized_keys_path), exist_ok=True)
if not os.path.exists(authorized_keys_path):
# Since the file does not exist, create it
open(local_authorized_keys_path, "a").close()
open(authorized_keys_path, "a").close()
write = True
else:
with open(local_authorized_keys_path, "r") as authorized_keys_file:
with open(authorized_keys_path, "r") as authorized_keys_file:
content = authorized_keys_file.read()
if public_key not in content:
write = True
else:
write = False

if write:
with open(local_authorized_keys_path, "a") as authorized_keys_file:
with open(authorized_keys_path, "a") as authorized_keys_file:
authorized_keys_file.write(public_key + "\n")

except docker.errors.APIError as e:
Expand Down
5 changes: 3 additions & 2 deletions commit0/harness/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,19 @@ def main(
for name in tqdm(log_dirs):
report_file = os.path.join(name, "report.json")
name = name.split("/")[2]
test_ids = get_tests(name, stdout=False)
if not os.path.exists(report_file):
out.append(
{
"name": name,
"sum": 0,
"passed": 0,
"num_passed": 0,
"num_tests": len(test_ids),
}
)
continue
report = load_dataset("json", data_files=report_file, split="train") # type: ignore
test_ids = get_tests(name, stdout=False)
tests = {x["nodeid"]: x["call"] for x in report["tests"][0]} # type: ignore
status = []
runtimes = []
Expand All @@ -110,7 +111,7 @@ def main(
"sum": total,
"passed": passed,
"num_passed": status["passed"] + status["xfail"],
"num_tests": sum(status.values()),
"num_tests": len(test_ids),
}
)
print("repo,runtime,num_passed/num_tests")
Expand Down
19 changes: 10 additions & 9 deletions commit0/harness/run_pytest_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
extract_test_output,
get_hash_string,
get_ip,
get_user,
)
from commit0.harness.execution_context import (
Docker,
Expand Down Expand Up @@ -74,7 +73,6 @@ def main(
commit_id=commit_id,
test_ids=test_ids,
ip=get_ip(backend),
user=get_user(),
)
eval_file = Path(log_dir / "eval.sh")
eval_file.write_text(eval_script)
Expand All @@ -96,18 +94,21 @@ def main(
output, "--json-report --json-report-file=report.json"
)
context.write_test_output(test_output, timed_out)
print(test_output)
except EvaluationError as e:
error_msg = traceback.format_exc()
logger.info(error_msg)
print(e)
error_msg = (
f"Error in running pytest for {repo}: {e}\n"
f"{traceback.format_exc()}\n"
f"Check ({log_file}) for more information."
)
raise EvaluationError(repo, error_msg, logger)
except Exception as e:
error_msg = (
f"Error in running pytest for {spec.repo}: {e}\n"
f"General error: {e}\n"
f"{traceback.format_exc()}\n"
# f"Check ({logger.log_file}) for more information."
f"Check ({log_file}) for more information."
)
logger.error(error_msg)

raise RuntimeError(error_msg)
return str(log_dir)


Expand Down
9 changes: 8 additions & 1 deletion commit0/harness/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from datasets import load_dataset

from typing import Iterator
from commit0.harness.utils import clone_repo, create_branch
from commit0.harness.utils import (
clone_repo,
create_branch,
setup_git,
add_safe_directory,
)
from commit0.harness.constants import RepoInstance, SPLIT


Expand All @@ -18,6 +23,7 @@ def main(
dataset_name: str, dataset_split: str, repo_split: str, base_dir: str, branch: str
) -> None:
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
setup_git(logger)
for example in dataset:
repo_name = example["repo"].split("/")[-1]
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
Expand All @@ -26,6 +32,7 @@ def main(
clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
local_repo = clone_repo(clone_url, clone_dir, example["base_commit"], logger)
create_branch(local_repo, branch, logger)
add_safe_directory(clone_dir, logger)


__all__ = []
2 changes: 1 addition & 1 deletion commit0/harness/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def make_eval_script_list(instance: RepoInstance, repo_directory: str) -> list[s
"ssh-keyscan {ip} >> ~/.ssh/known_hosts",
f"cd {repo_directory}",
"source .venv/bin/activate",
f"git remote add {origin_name} ssh://{{user}}@{{ip}}:{{local_repo}}",
f"git remote add {origin_name} ssh://git@{{ip}}:{{local_repo}}",
f"git fetch {origin_name}",
"git checkout {commit_id}",
"git status",
Expand Down
97 changes: 93 additions & 4 deletions commit0/harness/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import getpass
import git
import git.exc
import hashlib
Expand All @@ -7,7 +6,8 @@
import os
import time
import requests
from typing import Optional
import subprocess
from typing import Optional, Tuple

from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
from ghapi.core import GhApi
Expand Down Expand Up @@ -58,8 +58,97 @@ def get_ip(backend: str) -> str:
return ip


def get_user() -> str:
return getpass.getuser()
def run_command(command: str) -> Tuple[str, str, int]:
"""Runs a shell command and returns the output, error message, and exit code."""
try:
result = subprocess.run(
command,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
return (
result.stdout.decode("utf-8"),
result.stderr.decode("utf-8"),
result.returncode,
)
except subprocess.CalledProcessError as e:
return e.stdout.decode("utf-8"), e.stderr.decode("utf-8"), e.returncode


def handle_command(command: str, description: str, logger: logging.Logger) -> None:
"""Runs a command and handles success or failure with appropriate messages."""
stdout, stderr, exit_code = run_command(command)
if exit_code != 0:
logger.error(f"Error running '{command}' which {description}:\n{stderr}")
else:
logger.info(f"Succeeded in running '{command}' which {description}")


def setup_git(logger: logging.Logger) -> None:
"""Sets up the 'git' user with appropriate shell settings, .ssh directory, and git-shell as login shell."""
handle_command(
'sudo adduser --disabled-password --gecos "" git', "adds git user", logger
)

# Get git user's home directory dynamically
git_home_command = "getent passwd git | cut -d: -f6"
stdout, stderr, exit_code = run_command(git_home_command)
if exit_code != 0:
raise RuntimeError(f"Error getting git user's home directory: {stderr}")
git_home = stdout.strip() # Extract and trim the home directory

# Commands to be executed
commands = [
(f"sudo chmod 755 {git_home}", "make home of git viewable by others"),
(
f"sudo sh -c 'mkdir -p {git_home}/.ssh && chmod 755 {git_home}/.ssh && touch {git_home}/.ssh/authorized_keys && chmod 666 {git_home}/.ssh/authorized_keys'",
"sets up .ssh directory for git",
),
("sudo touch /etc/shells", "creates /etc/shells if it doesn't exist yet"),
("cat /etc/shells", "views available shells"),
(
"sudo sh -c 'which git-shell >> /etc/shells'",
"adds git-shell to /etc/shells",
),
(
"sudo chsh git -s $(which git-shell)",
"changes shell for git user to git-shell",
),
]

# Execute each command
for command, description in commands:
handle_command(command, description, logger)


def is_safe_directory_added(safe_directory: str) -> bool:
# Run command to get all safe directories
command = "sudo git config --system --get-all safe.directory"
stdout, stderr, exit_code = run_command(command)

# Check if the directory is listed
if exit_code == 0 and safe_directory in stdout.splitlines():
return True
else:
return False


def add_safe_directory(safe_directory: str, logger: logging.Logger) -> None:
safe_directory = os.path.join(safe_directory, ".git")
# Check if the directory is already added
if not is_safe_directory_added(safe_directory):
# Command to add the directory to safe.directory
command = f"sudo git config --system --add safe.directory {safe_directory}"
stdout, stderr, exit_code = run_command(command)

if exit_code == 0:
logger.info(f"Directory '{safe_directory}' added to safe.directory.")
else:
logger.error(f"Error adding directory: {stderr}")
else:
logger.info(f"Directory '{safe_directory}' is already in the list.")


def get_hash_string(input_string: str) -> str:
Expand Down
Loading