Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/autopilot_loop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,29 @@ def success(self):
return self.exit_code == 0


def _stream_and_capture(pipe, echo_to, captured):
def _stream_and_capture(pipe, echo_to, captured, max_bytes=0):
"""Read from pipe line-by-line, echo to a stream, and capture all output."""
truncated = False
try:
for line in iter(pipe.readline, b""):
decoded = line.decode("utf-8", errors="replace")
captured.write(decoded)
if max_bytes > 0 and captured.tell() >= max_bytes:
if not truncated:
captured.write(
"\n[OUTPUT TRUNCATED at %d bytes]\n" % max_bytes
)
truncated = True
else:
captured.write(decoded)
if echo_to:
echo_to.write(decoded)
echo_to.flush()
finally:
pipe.close()


def run_agent(prompt, session_dir, model="claude-opus-4.6", timeout=1800, extra_flags=None):
def run_agent(prompt, session_dir, model="claude-opus-4.6", timeout=1800,
extra_flags=None, max_output_bytes=0):
"""Run copilot CLI in non-interactive mode.

Streams stdout to the terminal in real-time so you can see progress
Expand All @@ -60,6 +69,7 @@ def run_agent(prompt, session_dir, model="claude-opus-4.6", timeout=1800, extra_
model: Model name for --model flag.
timeout: Timeout in seconds (SIGTERM, then SIGKILL after 30s grace).
extra_flags: Additional CLI flags as a list of strings.
max_output_bytes: Max bytes to capture in memory (0 = unlimited).

Returns:
AgentResult with exit code, session file path, stdout, stderr, duration.
Expand Down Expand Up @@ -109,12 +119,12 @@ def run_agent(prompt, session_dir, model="claude-opus-4.6", timeout=1800, extra_

stdout_thread = threading.Thread(
target=_stream_and_capture,
args=(proc.stdout, sys.stdout, stdout_captured),
args=(proc.stdout, sys.stdout, stdout_captured, max_output_bytes),
daemon=True,
)
stderr_thread = threading.Thread(
target=_stream_and_capture,
args=(proc.stderr, None, stderr_captured),
args=(proc.stderr, None, stderr_captured, max_output_bytes),
daemon=True,
)
stdout_thread.start()
Expand Down
17 changes: 13 additions & 4 deletions src/autopilot_loop/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,13 @@ def cmd_start(args):
sys.exit(1)

if args.issue:
from autopilot_loop.github_api import get_issue
from autopilot_loop.github_api import GitHubAPIError, get_issue
issue_number, repo = _parse_issue_arg(args.issue)
issue = get_issue(issue_number, repo=repo)
try:
issue = get_issue(issue_number, repo=repo)
except GitHubAPIError as exc:
print("Error: could not fetch issue: %s" % exc, file=sys.stderr)
sys.exit(1)
if repo:
prompt = "Issue %s#%d: %s\n\n%s" % (repo, issue_number, issue["title"], issue["body"][:4000])
else:
Expand All @@ -202,8 +206,13 @@ def cmd_start(args):
if not os.path.isfile(prompt_file):
print("Error: file not found: %s" % prompt_file, file=sys.stderr)
sys.exit(1)
with open(prompt_file, "r") as f:
prompt = f.read().strip()
try:
with open(prompt_file, "r") as f:
prompt = f.read().strip()
except (PermissionError, OSError) as exc:
print("Error: could not read file %s: %s" % (prompt_file, exc),
file=sys.stderr)
sys.exit(1)
if not prompt:
print("Error: file is empty: %s" % prompt_file, file=sys.stderr)
sys.exit(1)
Expand Down
10 changes: 9 additions & 1 deletion src/autopilot_loop/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"ci_check_names": [],
"ci_poll_interval_seconds": 120,
"ci_poll_timeout_seconds": 5400,
"max_output_bytes": 52428800,
}

CONFIG_FILENAMES = [
Expand All @@ -53,7 +54,12 @@ def load_config(cli_overrides=None):
if os.path.isfile(path):
logger.info("Loading config from %s", path)
with open(path, "r") as f:
file_config = json.load(f)
try:
file_config = json.load(f)
except json.JSONDecodeError as exc:
raise ValueError(
"Invalid JSON in config file %s: %s" % (path, exc)
)
config.update(file_config)
break
else:
Expand Down Expand Up @@ -83,3 +89,5 @@ def _validate(config):
raise ValueError("agent_timeout_seconds must be >= 60, got %d" % config["agent_timeout_seconds"])
if "{task_id}" not in config["branch_pattern"]:
raise ValueError("branch_pattern must contain {task_id}, got '%s'" % config["branch_pattern"])
if config["max_output_bytes"] < 1048576:
raise ValueError("max_output_bytes must be >= 1048576 (1 MB), got %d" % config["max_output_bytes"])
14 changes: 12 additions & 2 deletions src/autopilot_loop/github_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ def get_issue(issue_number, repo=None):
if repo:
cmd.extend(["--repo", repo])
output = _run_gh(cmd)
return json.loads(output)
try:
return json.loads(output)
except json.JSONDecodeError as exc:
raise GitHubAPIError(
"Failed to parse issue response: %s" % exc
)


def get_pr_description(pr_number):
Expand All @@ -199,7 +204,12 @@ def get_pr_description(pr_number):
"pr", "view", str(pr_number),
"--json", "title,body",
])
return json.loads(output)
try:
return json.loads(output)
except json.JSONDecodeError as exc:
raise GitHubAPIError(
"Failed to parse PR description response: %s" % exc
)


def update_pr_description(pr_number, body):
Expand Down
7 changes: 5 additions & 2 deletions src/autopilot_loop/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _run_agent_with_retry(self, phase, prompt, session_name):
model=self.config.get("model", "claude-opus-4.6"),
timeout=self.config.get("agent_timeout_seconds", 1800),
extra_flags=self._get_extra_flags(),
max_output_bytes=self.config.get("max_output_bytes", 0),
)
ended_at = time.time()

Expand Down Expand Up @@ -252,7 +253,7 @@ def _run_agent_with_retry(self, phase, prompt, session_name):
def _do_verify_push(self):
"""Verify new commits were pushed after fix."""
branch = self.task["branch"]
pre_sha = getattr(self, "_pre_fix_sha", None)
pre_sha = self.task.get("pre_fix_sha") or getattr(self, "_pre_fix_sha", None)

if pre_sha and verify_new_commits(branch, pre_sha):
logger.info("[%s] \u2713 New commits found on %s", self.task_id, branch)
Expand Down Expand Up @@ -737,8 +738,9 @@ def _do_fix(self):
task_context=self.task.get("prompt", ""),
)

# Record head SHA before fix
# Record head SHA before fix — persisted for crash recovery
self._pre_fix_sha = get_head_sha(self.task["branch"])
update_task(self.task_id, pre_fix_sha=self._pre_fix_sha)

result = self._run_agent_with_retry("FIX", prompt, "fix-%d" % iteration)
if result is None:
Expand Down Expand Up @@ -1003,6 +1005,7 @@ def _do_fix_ci(self):
)

self._pre_fix_sha = get_head_sha(self.task["branch"])
update_task(self.task_id, pre_fix_sha=self._pre_fix_sha)

result = self._run_agent_with_retry("FIX_CI", prompt, "fix-ci-%d" % iteration)
if result is None:
Expand Down
6 changes: 4 additions & 2 deletions src/autopilot_loop/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

# Bump this when the schema changes. Additive changes (new nullable columns)
# are handled by _migrate(). Breaking changes trigger a DB recreate.
SCHEMA_VERSION = 6
SCHEMA_VERSION = 7

SCHEMA = """
CREATE TABLE IF NOT EXISTS schema_meta (
Expand All @@ -64,6 +64,7 @@
existing_branch INTEGER NOT NULL DEFAULT 0,
original_idle_timeout INTEGER,
prompt_file TEXT,
pre_fix_sha TEXT,
created_at REAL NOT NULL,
updated_at REAL NOT NULL
);
Expand Down Expand Up @@ -99,6 +100,7 @@
(4, "tasks", "existing_branch", "INTEGER NOT NULL DEFAULT 0"),
(5, "tasks", "original_idle_timeout", "INTEGER"),
(6, "tasks", "prompt_file", "TEXT"),
(7, "tasks", "pre_fix_sha", "TEXT"),
]


Expand Down Expand Up @@ -202,7 +204,7 @@ def get_task(task_id):
"prompt", "state", "pr_number", "branch", "iteration",
"max_iterations", "plan_mode", "dry_run", "model", "last_review_id",
"task_mode", "ci_check_names", "pre_stop_state", "existing_branch",
"original_idle_timeout", "prompt_file", "updated_at",
"original_idle_timeout", "prompt_file", "pre_fix_sha", "updated_at",
})


Expand Down
35 changes: 35 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,38 @@ def test_session_file_path(self, session_dir):
result = run_agent("prompt", session_dir)

assert result.session_file == os.path.join(session_dir, "session.md")

def test_output_truncated_when_limit_exceeded(self, session_dir):
"""When max_output_bytes is set and exceeded, output is truncated."""
# Generate output larger than the limit
large_output = b"x" * 100 + b"\n" # 101 bytes per line
lines = large_output * 20 # ~2020 bytes total
mock_proc = _mock_proc(stdout=lines)

with patch("autopilot_loop.agent.subprocess.Popen", return_value=mock_proc):
result = run_agent("prompt", session_dir, max_output_bytes=500)

assert "[OUTPUT TRUNCATED at 500 bytes]" in result.stdout
# Captured output should be limited (truncation marker + some lines)
assert len(result.stdout) < 2020

def test_output_not_truncated_when_under_limit(self, session_dir):
"""When output is under max_output_bytes, nothing is truncated."""
mock_proc = _mock_proc(stdout=b"small output\n")

with patch("autopilot_loop.agent.subprocess.Popen", return_value=mock_proc):
result = run_agent("prompt", session_dir, max_output_bytes=50000)

assert "small output" in result.stdout
assert "TRUNCATED" not in result.stdout

def test_output_unlimited_when_zero(self, session_dir):
"""When max_output_bytes is 0 (default), output is not truncated."""
large_output = b"x" * 1000 + b"\n"
mock_proc = _mock_proc(stdout=large_output * 10)

with patch("autopilot_loop.agent.subprocess.Popen", return_value=mock_proc):
result = run_agent("prompt", session_dir, max_output_bytes=0)

assert "TRUNCATED" not in result.stdout
assert len(result.stdout) > 5000
41 changes: 41 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,4 +663,45 @@ def fake_get_issue(num, repo=None):
assert len(calls) == 1
assert calls[0] == (42, None)

def test_issue_fetch_failure_exits(self, monkeypatch, capsys):
"""--issue that fails to fetch prints error and exits 1."""
self._stub_start_deps(monkeypatch)

from autopilot_loop.github_api import GitHubAPIError

def fake_get_issue(num, repo=None):
raise GitHubAPIError("gh command failed (exit 1)")

monkeypatch.setattr("autopilot_loop.github_api.get_issue", fake_get_issue)

with pytest.raises(SystemExit) as exc_info:
cmd_start(self._make_args(issue="42"))

assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "could not fetch issue" in captured.err

def test_file_permission_error_exits(self, monkeypatch, capsys):
"""--file with unreadable file prints error and exits 1."""
self._stub_start_deps(monkeypatch)

import builtins
real_open = builtins.open

def fake_open(path, *a, **kw):
if "unreadable" in str(path):
raise PermissionError("Permission denied: %s" % path)
return real_open(path, *a, **kw)

monkeypatch.setattr(builtins, "open", fake_open)
# os.path.isfile must return True for the file
monkeypatch.setattr("os.path.isfile", lambda p: True)

with pytest.raises(SystemExit) as exc_info:
cmd_start(self._make_args(prompt=None, file="/tmp/unreadable.txt"))

assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "could not read file" in captured.err


25 changes: 25 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,28 @@ def test_validation_agent_timeout():
def test_validation_branch_pattern():
with pytest.raises(ValueError, match="branch_pattern"):
load_config({"branch_pattern": "no-task-id-placeholder"})


def test_validation_max_output_bytes():
with pytest.raises(ValueError, match="max_output_bytes"):
load_config({"max_output_bytes": 100})


def test_max_output_bytes_default():
config = load_config()
assert config["max_output_bytes"] == 52428800


def test_malformed_json_config_file(tmp_path, monkeypatch):
"""Malformed JSON config file raises ValueError with useful message."""
config_file = tmp_path / "autopilot.json"
config_file.write_text("{invalid json!!!")

import autopilot_loop.config as config_module
original = config_module.CONFIG_FILENAMES
config_module.CONFIG_FILENAMES = [str(config_file)]
try:
with pytest.raises(ValueError, match="Invalid JSON"):
load_config()
finally:
config_module.CONFIG_FILENAMES = original
14 changes: 14 additions & 0 deletions tests/test_github_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ def test_returns_issue(self):
result = get_issue(123)
assert result["title"] == "Bug in X"

def test_invalid_json_raises_api_error(self):
from autopilot_loop.github_api import GitHubAPIError
with patch("autopilot_loop.github_api.subprocess.run",
return_value=_mock_run("not valid json")):
with pytest.raises(GitHubAPIError, match="Failed to parse issue"):
get_issue(123)


class TestVerifyNewCommits:
@patch("autopilot_loop.github_api.get_head_sha")
Expand Down Expand Up @@ -595,6 +602,13 @@ def test_api_error_raises(self):
with pytest.raises(GitHubAPIError):
get_pr_description(999)

def test_invalid_json_raises_api_error(self):
from autopilot_loop.github_api import GitHubAPIError, get_pr_description
with patch("autopilot_loop.github_api.subprocess.run",
return_value=_mock_run("bad json data")):
with pytest.raises(GitHubAPIError, match="Failed to parse PR description"):
get_pr_description(42)


class TestUpdatePrDescription:
def test_calls_gh_pr_edit(self):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def test_migration_from_pre_versioned_db(tmp_path, monkeypatch):
assert task["existing_branch"] == 0
assert task["original_idle_timeout"] is None
assert task["prompt_file"] is None
assert task["pre_fix_sha"] is None

# New columns should be usable
persistence.update_task("old1", task_mode="ci", ci_check_names='["check-a"]')
Expand Down Expand Up @@ -253,3 +254,17 @@ def test_prompt_file_persists(tmp_path, monkeypatch):
persistence.update_task("t1", prompt_file="/tmp/my-task.txt")
task = persistence.get_task("t1")
assert task["prompt_file"] == "/tmp/my-task.txt"


def test_pre_fix_sha_persists(tmp_path, monkeypatch):
"""pre_fix_sha column can be written and read back."""
monkeypatch.setattr(persistence, "DB_DIR", str(tmp_path))
monkeypatch.setattr(persistence, "DB_PATH", str(tmp_path / "state.db"))

persistence.create_task("t1", "prompt")
task = persistence.get_task("t1")
assert task["pre_fix_sha"] is None

persistence.update_task("t1", pre_fix_sha="abc123def456")
task = persistence.get_task("t1")
assert task["pre_fix_sha"] == "abc123def456"
Loading