Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions src/autopilot_loop/github_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import re
import subprocess
import time

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -39,9 +40,34 @@ class GitHubAPIError(Exception):
pass


_TRANSIENT_PATTERNS = [
"rate limit",
"abuse detection",
"server error",
"502",
"503",
"504",
"timed out",
"connection refused",
"connection reset",
"network is unreachable",
]

_MAX_RETRIES = 3


def _is_transient(stderr):
"""Check if a gh CLI error looks transient (retryable)."""
lower = stderr.lower()
return any(p in lower for p in _TRANSIENT_PATTERNS)


def _run_gh(args, check=True):
    """Run a gh CLI command and return its stripped stdout.

    Transient failures (rate limits, 5xx responses, network errors) are
    retried up to _MAX_RETRIES times with exponential backoff (1s, 2s, 4s).
    Permanent errors are never retried.

    Args:
        args: Command arguments as a list (without the 'gh' prefix).
        check: If True, raise GitHubAPIError on non-zero exit.

    Returns:
        The command's stdout with surrounding whitespace stripped.

    Raises:
        GitHubAPIError: If check is True and the command ultimately fails.
    """
    cmd = ["gh"] + args
    logger.debug("Running: %s", " ".join(cmd))

    attempt = 0
    while True:
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode == 0:
            return result.stdout.strip()

        # Permanent errors and an exhausted retry budget fall through to
        # the error handling below; only transient errors are retried.
        if attempt >= _MAX_RETRIES or not _is_transient(result.stderr):
            break

        delay = 2 ** attempt  # 1s, 2s, 4s
        logger.warning(
            "gh command failed (attempt %d/%d), retrying in %ds: %s",
            attempt + 1, _MAX_RETRIES + 1, delay,
            result.stderr.strip()[:200],
        )
        time.sleep(delay)
        attempt += 1

    if check:
        raise GitHubAPIError(
            "gh command failed (exit %d): %s\nstderr: %s"
            % (result.returncode, " ".join(cmd),
               result.stderr.strip())
        )
    return result.stdout.strip()


_nwo_cache = None
Expand Down Expand Up @@ -313,9 +361,12 @@ def get_unresolved_review_comments(pr_number):
# Check for GraphQL errors
if "errors" in data:
msgs = [e.get("message", str(e)) for e in data["errors"]]
logger.warning("GraphQL errors in get_unresolved_review_comments: %s", "; ".join(msgs))
msg = "; ".join(msgs)
if not data.get("data"):
return []
raise GitHubAPIError(
"GraphQL error in get_unresolved_review_comments: %s" % msg
)
logger.warning("GraphQL partial errors in get_unresolved_review_comments: %s", msg)

threads = (
data.get("data", {})
Expand Down Expand Up @@ -428,9 +479,12 @@ def get_latest_copilot_review_thread_ts(pr_number):

if "errors" in data:
msgs = [e.get("message", str(e)) for e in data["errors"]]
logger.warning("GraphQL errors in get_latest_copilot_review_thread_ts: %s", "; ".join(msgs))
msg = "; ".join(msgs)
if not data.get("data"):
return None
raise GitHubAPIError(
"GraphQL error in get_latest_copilot_review_thread_ts: %s" % msg
)
logger.warning("GraphQL partial errors in get_latest_copilot_review_thread_ts: %s", msg)

threads = (
data.get("data", {})
Expand Down
9 changes: 8 additions & 1 deletion src/autopilot_loop/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def __init__(self, task_id, config):
self.config = config
self.task = get_task(task_id)
self.sessions_dir = get_sessions_dir(task_id)
self._retry_counts = {} # phase -> retry count
# Restore retry counts from DB (survives crash/resume)
raw = self.task.get("retry_counts_json") or "{}"
try:
self._retry_counts = json.loads(raw)
except (ValueError, TypeError):
self._retry_counts = {}

def _get_handlers(self):
"""Return a dict mapping state names to handler methods. Subclasses must override."""
Expand Down Expand Up @@ -273,6 +278,7 @@ def _do_verify_push(self):

logger.warning("[%s] No new commits on %s, retrying fix", self.task_id, branch)
self._retry_counts[retry_key] = 1
update_task(self.task_id, retry_counts_json=json.dumps(self._retry_counts))
return self._retry_fix_state()

def _after_verify_push(self):
Expand Down Expand Up @@ -378,6 +384,7 @@ def _do_verify_pr(self):

logger.warning("[%s] No PR found for branch %s, retrying IMPLEMENT with explicit prompt", self.task_id, branch)
self._retry_counts[retry_key] = 1
update_task(self.task_id, retry_counts_json=json.dumps(self._retry_counts))

# Retry with a more explicit prompt
explicit_prompt = (
Expand Down
7 changes: 5 additions & 2 deletions src/autopilot_loop/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

# Bump this when the schema changes. Additive changes (new nullable columns)
# are handled by _migrate(). Breaking changes trigger a DB recreate.
SCHEMA_VERSION = 7
SCHEMA_VERSION = 8

SCHEMA = """
CREATE TABLE IF NOT EXISTS schema_meta (
Expand All @@ -65,6 +65,7 @@
original_idle_timeout INTEGER,
prompt_file TEXT,
pre_fix_sha TEXT,
retry_counts_json TEXT,
created_at REAL NOT NULL,
updated_at REAL NOT NULL
);
Expand Down Expand Up @@ -101,6 +102,7 @@
(5, "tasks", "original_idle_timeout", "INTEGER"),
(6, "tasks", "prompt_file", "TEXT"),
(7, "tasks", "pre_fix_sha", "TEXT"),
(8, "tasks", "retry_counts_json", "TEXT"),
]


Expand Down Expand Up @@ -204,7 +206,8 @@ def get_task(task_id):
"prompt", "state", "pr_number", "branch", "iteration",
"max_iterations", "plan_mode", "dry_run", "model", "last_review_id",
"task_mode", "ci_check_names", "pre_stop_state", "existing_branch",
"original_idle_timeout", "prompt_file", "pre_fix_sha", "updated_at",
"original_idle_timeout", "prompt_file", "pre_fix_sha",
"retry_counts_json", "updated_at",
})


Expand Down
138 changes: 138 additions & 0 deletions tests/test_github_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,141 @@ def test_api_error_raises(self):
return_value=_mock_run("", returncode=1, stderr="error")):
with pytest.raises(GitHubAPIError):
update_pr_description(42, "body")


class TestRunGhRetry:
    """Tests for _run_gh retry behavior on transient gh CLI errors."""

    def test_retries_on_rate_limit(self):
        # Two rate-limit failures, then success: _run_gh should retry twice.
        from autopilot_loop.github_api import _run_gh
        responses = [
            _mock_run("", returncode=1, stderr="API rate limit exceeded"),
            _mock_run("", returncode=1, stderr="API rate limit exceeded"),
            _mock_run("success"),
        ]
        with patch("autopilot_loop.github_api.subprocess.run",
                   side_effect=responses) as mock_run:
            with patch("autopilot_loop.github_api.time.sleep"):
                result = _run_gh(["api", "test"])

        assert result == "success"
        assert mock_run.call_count == 3

    def test_no_retry_on_permanent_error(self):
        # A 404 is not transient: exactly one attempt, then raise.
        from autopilot_loop.github_api import GitHubAPIError, _run_gh
        failure = _mock_run("", returncode=1, stderr="Not Found (HTTP 404)")
        with patch("autopilot_loop.github_api.subprocess.run",
                   return_value=failure) as mock_run:
            with pytest.raises(GitHubAPIError):
                _run_gh(["api", "test"])

        assert mock_run.call_count == 1

    def test_retries_on_server_error(self):
        # A single 502 is retried once and then succeeds.
        from autopilot_loop.github_api import _run_gh
        responses = [
            _mock_run("", returncode=1, stderr="502 Bad Gateway"),
            _mock_run("ok"),
        ]
        with patch("autopilot_loop.github_api.subprocess.run",
                   side_effect=responses) as mock_run:
            with patch("autopilot_loop.github_api.time.sleep"):
                result = _run_gh(["api", "test"])

        assert result == "ok"
        assert mock_run.call_count == 2

    def test_gives_up_after_max_retries(self):
        # A persistently failing transient error eventually raises.
        from autopilot_loop.github_api import GitHubAPIError, _run_gh
        failure = _mock_run("", returncode=1, stderr="503 Service Unavailable")
        with patch("autopilot_loop.github_api.subprocess.run",
                   return_value=failure):
            with patch("autopilot_loop.github_api.time.sleep"):
                with pytest.raises(GitHubAPIError):
                    _run_gh(["api", "test"])

    def test_check_false_no_raise(self):
        # With check=False a failure returns stdout instead of raising.
        from autopilot_loop.github_api import _run_gh
        failure = _mock_run("", returncode=1, stderr="Not Found")
        with patch("autopilot_loop.github_api.subprocess.run",
                   return_value=failure):
            result = _run_gh(["api", "test"], check=False)

        assert result == ""


class TestGraphQLErrorRaising:
    """Tests for GraphQL error raising in review comment functions."""

    @staticmethod
    def _partial_error_response():
        # A payload carrying partial errors alongside usable (empty) data.
        return {
            "errors": [{"message": "some warning"}],
            "data": {
                "repository": {
                    "pullRequest": {"reviewThreads": {"nodes": []}}
                }
            },
        }

    def test_unresolved_comments_raises_on_graphql_error_no_data(self):
        from autopilot_loop.github_api import GitHubAPIError
        payload = json.dumps({"errors": [{"message": "auth required"}]})
        with patch("autopilot_loop.github_api._run_gh",
                   side_effect=["octocat/hello-world", payload]):
            with pytest.raises(GitHubAPIError, match="auth required"):
                get_unresolved_review_comments(42)

    def test_unresolved_comments_partial_error_returns_data(self):
        payload = json.dumps(self._partial_error_response())
        with patch("autopilot_loop.github_api._run_gh",
                   side_effect=["octocat/hello-world", payload]):
            result = get_unresolved_review_comments(42)
        assert result == []

    def test_latest_thread_ts_raises_on_graphql_error_no_data(self):
        from autopilot_loop.github_api import GitHubAPIError
        payload = json.dumps({"errors": [{"message": "rate limited"}]})
        with patch("autopilot_loop.github_api._run_gh",
                   side_effect=["octocat/hello-world", payload]):
            with pytest.raises(GitHubAPIError, match="rate limited"):
                get_latest_copilot_review_thread_ts(42)

    def test_latest_thread_ts_partial_error_returns_data(self):
        payload = json.dumps(self._partial_error_response())
        with patch("autopilot_loop.github_api._run_gh",
                   side_effect=["octocat/hello-world", payload]):
            result = get_latest_copilot_review_thread_ts(42)
        assert result is None
16 changes: 16 additions & 0 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def test_migration_from_pre_versioned_db(tmp_path, monkeypatch):
assert task["original_idle_timeout"] is None
assert task["prompt_file"] is None
assert task["pre_fix_sha"] is None
assert task["retry_counts_json"] is None

# New columns should be usable
persistence.update_task("old1", task_mode="ci", ci_check_names='["check-a"]')
Expand Down Expand Up @@ -268,3 +269,18 @@ def test_pre_fix_sha_persists(tmp_path, monkeypatch):
persistence.update_task("t1", pre_fix_sha="abc123def456")
task = persistence.get_task("t1")
assert task["pre_fix_sha"] == "abc123def456"


def test_retry_counts_json_persists(tmp_path, monkeypatch):
    """The retry_counts_json column round-trips through update/get."""
    monkeypatch.setattr(persistence, "DB_DIR", str(tmp_path))
    monkeypatch.setattr(persistence, "DB_PATH", str(tmp_path / "state.db"))

    # A freshly created task has no retry counts recorded.
    persistence.create_task("t1", "prompt")
    assert persistence.get_task("t1")["retry_counts_json"] is None

    # Writing a JSON blob and reading it back returns it unchanged.
    counts = '{"VERIFY_PUSH_FIX_RETRY": 1}'
    persistence.update_task("t1", retry_counts_json=counts)
    assert persistence.get_task("t1")["retry_counts_json"] == counts
Loading