Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions bin/lib/cli/blue_green.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
elb_client,
get_current_key,
get_releases,
get_ssm_param,
release_for,
)
from lib.aws_utils import get_asg_info, reset_asg_min_size, scale_asg
Expand All @@ -20,6 +19,7 @@
from lib.ce_utils import are_you_sure, display_releases
from lib.cli import cli
from lib.env import BLUE_GREEN_ENABLED_ENVIRONMENTS, Config, Environment
from lib.github_app import get_github_app_token
from lib.notify import handle_notify


Expand Down Expand Up @@ -310,19 +310,8 @@ def blue_green_deploy(

# Show what would be notified
print("\n🔍 Checking what would be notified...")
try:
gh_token = get_ssm_param("/compiler-explorer/githubAuthToken")
handle_notify(original_commit_hash, target_commit_hash, gh_token, dry_run=True)
except ClientError as e:
print(f"⚠️ Could not retrieve GitHub token ({e})")
print("🔍 Showing commit range that would be checked:")
print(" GitHub API would be queried for commits between:")
print(f" {original_commit_hash} (current deployment)")
print(f" {target_commit_hash} (target deployment)")
print(
f" URL: https://github.com/compiler-explorer/compiler-explorer/compare/{original_commit_hash[:8]}...{target_commit_hash[:8]}"
)
print(" Each commit's linked PRs and issues would be notified with 'This is now live' messages")
gh_token = get_github_app_token()
handle_notify(original_commit_hash, target_commit_hash, gh_token, dry_run=True)

return

Expand Down Expand Up @@ -382,7 +371,7 @@ def blue_green_deploy(
if should_notify and cfg.env == Environment.PROD:
if original_commit_hash is not None and target_commit_hash is not None:
try:
gh_token = get_ssm_param("/compiler-explorer/githubAuthToken")
gh_token = get_github_app_token()
print(f"\n{'[DRY RUN] ' if dry_run_notify else ''}Checking for notifications...")
print(
f"Checking commits from {original_commit_hash[:8]}...{original_commit_hash[-8:]} to {target_commit_hash[:8]}...{target_commit_hash[-8:]}"
Expand Down
144 changes: 144 additions & 0 deletions bin/lib/github_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""GitHub App authentication for Compiler Explorer Bot."""

from __future__ import annotations

import json
import logging
import time
import urllib.error
import urllib.request

import jwt

from lib.amazon import get_ssm_param, ssm_client

LOGGER = logging.getLogger(__name__)

GITHUB_API_URL = "https://api.github.com"
USER_AGENT = "CE GitHub App Auth"

# SSM parameter names
SSM_APP_ID = "/compiler-explorer/github-app-id"
SSM_PRIVATE_KEY = "/compiler-explorer/github-app-private-key"


def generate_jwt(app_id: str, private_key: str) -> str:
"""Generate a JWT for GitHub App authentication.

Args:
app_id: The GitHub App ID
private_key: The private key in PEM format

Returns:
A JWT token valid for 10 minutes
"""
now = int(time.time())
payload = {
"iat": now - 60, # Issued 60 seconds ago to account for clock drift
"exp": now + (10 * 60), # Expires in 10 minutes
"iss": app_id,
}
return jwt.encode(payload, private_key, algorithm="RS256")


def get_installation_id(app_jwt: str, org: str = "compiler-explorer") -> int:
"""Get the installation ID for a GitHub App on an organization.

Args:
app_jwt: JWT token for the GitHub App
org: The organization name to find the installation for

Returns:
The installation ID

Raises:
RuntimeError: If the installation is not found or API request fails
"""
try:
req = urllib.request.Request(
f"{GITHUB_API_URL}/app/installations",
headers={
"User-Agent": USER_AGENT,
"Authorization": f"Bearer {app_jwt}",
"Accept": "application/vnd.github.v3+json",
},
)
result = urllib.request.urlopen(req)
installations = json.loads(result.read())

for installation in installations:
if installation.get("account", {}).get("login") == org:
return installation["id"]

raise RuntimeError(f"GitHub App is not installed on organization '{org}'")
except (OSError, urllib.error.URLError, json.JSONDecodeError) as e:
raise RuntimeError(f"Failed to get GitHub App installations: {e}") from e


def get_installation_token(app_jwt: str, installation_id: int) -> str:
"""Get an installation access token for a GitHub App.

Args:
app_jwt: JWT token for the GitHub App
installation_id: The installation ID

Returns:
An installation access token valid for 1 hour

Raises:
RuntimeError: If the token request fails
"""
try:
req = urllib.request.Request(
f"{GITHUB_API_URL}/app/installations/{installation_id}/access_tokens",
data=b"",
method="POST",
headers={
"User-Agent": USER_AGENT,
"Authorization": f"Bearer {app_jwt}",
"Accept": "application/vnd.github.v3+json",
},
)
result = urllib.request.urlopen(req)
response = json.loads(result.read())
return response["token"]
except (OSError, urllib.error.URLError, json.JSONDecodeError) as e:
raise RuntimeError(f"Failed to get installation access token: {e}") from e


def get_github_app_token() -> str:
"""Get a GitHub installation access token using credentials from SSM.

This function:
1. Retrieves the App ID and private key from AWS SSM
2. Generates a JWT signed with the private key
3. Finds the installation ID for the compiler-explorer org
4. Exchanges the JWT for an installation access token

Returns:
An installation access token for the GitHub App

Raises:
RuntimeError: If credentials are missing or authentication fails
"""
LOGGER.debug("Retrieving GitHub App credentials from SSM")

try:
app_id = get_ssm_param(SSM_APP_ID)
except Exception as e:
raise RuntimeError(f"Failed to get GitHub App ID from SSM ({SSM_APP_ID}): {e}") from e

try:
# Private key is stored as SecureString, needs WithDecryption
private_key = ssm_client.get_parameter(Name=SSM_PRIVATE_KEY, WithDecryption=True)["Parameter"]["Value"]
except Exception as e:
raise RuntimeError(f"Failed to get GitHub App private key from SSM ({SSM_PRIVATE_KEY}): {e}") from e

LOGGER.debug("Generating JWT for GitHub App")
app_jwt = generate_jwt(app_id, private_key)

LOGGER.debug("Getting installation ID for compiler-explorer org")
installation_id = get_installation_id(app_jwt)

LOGGER.debug("Getting installation access token")
return get_installation_token(app_jwt, installation_id)
4 changes: 2 additions & 2 deletions bin/lib/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def post(entity: str, token: str, query: dict | None = None, dry_run=False) -> d
data=querystring,
headers={
"User-Agent": USER_AGENT,
"Authorization": f"token {token}",
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
},
)
Expand All @@ -55,7 +55,7 @@ def get(entity: str, token: str, query: dict | None = None) -> dict:
None,
{
"User-Agent": USER_AGENT,
"Authorization": f"token {token}",
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github.v3+json",
},
)
Expand Down
162 changes: 162 additions & 0 deletions bin/test/github_app_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Tests for the github_app module."""

from __future__ import annotations

import json
from unittest.mock import MagicMock, patch

import jwt
import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from lib.github_app import (
generate_jwt,
get_github_app_token,
get_installation_id,
get_installation_token,
)


def generate_test_key_pair():
"""Generate a test RSA key pair."""
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend(),
)

pem_private = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode()

pem_public = (
private_key.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode()
)

return pem_private, pem_public


def test_generate_jwt_returns_string():
"""Test that generate_jwt returns a JWT string."""
test_private_key, _ = generate_test_key_pair()

jwt_token = generate_jwt("12345", test_private_key)

assert isinstance(jwt_token, str)
assert len(jwt_token.split(".")) == 3 # JWT has 3 parts separated by dots


def test_generate_jwt_contains_correct_claims():
"""Test that the generated JWT contains the correct claims."""
test_private_key, test_public_key = generate_test_key_pair()

jwt_token = generate_jwt("12345", test_private_key)

decoded = jwt.decode(jwt_token, test_public_key, algorithms=["RS256"])

assert decoded["iss"] == "12345"
assert "iat" in decoded
assert "exp" in decoded
# exp should be about 10 minutes after iat (with 60s clock drift adjustment)
assert abs((decoded["exp"] - decoded["iat"]) - 11 * 60) < 5


def test_get_installation_id_success():
"""Test successful installation ID retrieval."""
mock_response = MagicMock()
mock_response.read.return_value = json.dumps([
{"id": 111, "account": {"login": "other-org"}},
{"id": 222, "account": {"login": "compiler-explorer"}},
]).encode()

with patch("urllib.request.urlopen", return_value=mock_response):
result = get_installation_id("fake_jwt", org="compiler-explorer")

assert result == 222


def test_get_installation_id_not_found():
"""Test error when installation is not found."""
mock_response = MagicMock()
mock_response.read.return_value = json.dumps([
{"id": 111, "account": {"login": "other-org"}},
]).encode()

with patch("urllib.request.urlopen", return_value=mock_response):
with pytest.raises(RuntimeError, match="not installed"):
get_installation_id("fake_jwt", org="compiler-explorer")


def test_get_installation_token_success():
"""Test successful installation token retrieval."""
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"token": "ghs_xxxxxxxxxxxxxxxxxxxx",
"expires_at": "2024-01-01T00:00:00Z",
}).encode()

with patch("urllib.request.urlopen", return_value=mock_response):
result = get_installation_token("fake_jwt", 12345)

assert result == "ghs_xxxxxxxxxxxxxxxxxxxx"


def test_get_github_app_token_success():
"""Test successful end-to-end token retrieval."""
test_private_key, _ = generate_test_key_pair()

mock_ssm_client = MagicMock()
mock_ssm_client.get_parameter.return_value = {"Parameter": {"Value": test_private_key}}

mock_installations_response = MagicMock()
mock_installations_response.read.return_value = json.dumps([
{"id": 67890, "account": {"login": "compiler-explorer"}},
]).encode()

mock_token_response = MagicMock()
mock_token_response.read.return_value = json.dumps({
"token": "ghs_test_token_12345",
}).encode()

with (
patch("lib.github_app.get_ssm_param", return_value="12345"),
patch("lib.github_app.ssm_client", mock_ssm_client),
patch("urllib.request.urlopen", side_effect=[mock_installations_response, mock_token_response]),
):
result = get_github_app_token()

assert result == "ghs_test_token_12345"


def test_get_github_app_token_missing_app_id():
"""Test error when App ID is missing from SSM."""

def mock_get_ssm_param(param):
if "app-id" in param:
raise Exception("Parameter not found")
return "some_value"

with patch("lib.github_app.get_ssm_param", side_effect=mock_get_ssm_param):
with pytest.raises(RuntimeError, match="App ID"):
get_github_app_token()


def test_get_github_app_token_missing_private_key():
"""Test error when private key is missing from SSM."""
mock_ssm_client = MagicMock()
mock_ssm_client.get_parameter.side_effect = Exception("Parameter not found")

with (
patch("lib.github_app.get_ssm_param", return_value="12345"),
patch("lib.github_app.ssm_client", mock_ssm_client),
):
with pytest.raises(RuntimeError, match="private key"):
get_github_app_token()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"requests-cache>=1.2.1",
"matplotlib>=3.10.5",
"pillow>=11.3.0",
"PyJWT[crypto]>=2.8.0",
]

[dependency-groups]
Expand Down
Loading