Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/art/dev/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
from .engine import EngineArgs


def _generate_api_key() -> str:
"""Return a secure, unique API key for the vLLM server.

Prefers the ``ART_API_KEY`` environment variable so operators can pin
a known credential. Falls back to a cryptographically random token
that is different for every server invocation.
"""
import os
import secrets

return os.environ.get("ART_API_KEY") or secrets.token_urlsafe(32)


def get_openai_server_config(
model_name: str,
base_model: str,
Expand All @@ -27,7 +40,7 @@ def get_openai_server_config(
lora_modules = [f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}']

server_args = ServerArgs(
api_key="default",
api_key=_generate_api_key(),
lora_modules=lora_modules,
return_tokens_as_token_ids=True,
enable_auto_tool_choice=True,
Expand Down
9 changes: 8 additions & 1 deletion src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,21 @@ async def _prepare_backend_for_training(
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
server_args["port"] = s.getsockname()[1]

# Ensure the server and client share the same API key.
# If the caller did not supply one, generate a secure random key
# so the vLLM server is never exposed with a well-known credential.
if not server_args.get("api_key"):
server_args["api_key"] = dev._generate_api_key()

config_dict["server_args"] = server_args
resolved_config = cast(dev.OpenAIServerConfig, config_dict)

service = await self._get_service(model)
host, port = await service.start_openai_server(config=resolved_config)

base_url = f"http://{host}:{port}/v1"
api_key = server_args.get("api_key") or "default"
api_key = server_args["api_key"]

def done_callback(_: asyncio.Task[None]) -> None:
close_proxy(self._services.pop(model.name))
Expand Down
2 changes: 1 addition & 1 deletion src/art/tinker/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async def start(self) -> tuple[str, int]:
for i in range(self.num_workers or self._default_num_workers())
]
self._task = asyncio.create_task(self._run(host, port))
client = AsyncOpenAI(api_key="default", base_url=f"http://{host}:{port}/v1")
client = AsyncOpenAI(api_key="health-check", base_url=f"http://{host}:{port}/v1")
start = time.time()
while True:
timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0))
Expand Down
2 changes: 1 addition & 1 deletion src/art/tinker_native/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def _prepare_backend_for_training(
port = server_args.get("port", raw_config.get("port"))
if port is None:
port = get_free_port()
api_key = server_args.get("api_key", raw_config.get("api_key")) or "default"
api_key = server_args.get("api_key", raw_config.get("api_key")) or dev._generate_api_key()

if state.server_task is None:
state.server_host = host
Expand Down
134 changes: 134 additions & 0 deletions tests/unit/test_api_key_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Unit tests for secure API key generation in vLLM server configuration.

Verifies that the hardcoded 'default' API key (CWE-798) has been replaced
with a cryptographically random token, and that the ART_API_KEY environment
variable override works correctly.

See: https://github.com/OpenPipe/ART/issues/628
"""

import os

import pytest

from art.dev.openai_server import (
OpenAIServerConfig,
_generate_api_key,
get_openai_server_config,
)


class TestGenerateApiKey:
    """Tests for the ``_generate_api_key`` helper.

    Every test that exercises the random-token path first clears any
    ambient ``ART_API_KEY`` so results do not depend on the developer's
    or CI runner's shell environment (a pinned env key would otherwise
    make e.g. the uniqueness test fail spuriously).
    """

    def test_key_is_not_hardcoded_default(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The generated key must never be the literal string 'default'."""
        monkeypatch.delenv("ART_API_KEY", raising=False)
        key = _generate_api_key()
        assert key != "default"

    def test_key_is_nonempty_string(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The helper always returns a non-empty ``str``."""
        monkeypatch.delenv("ART_API_KEY", raising=False)
        key = _generate_api_key()
        assert isinstance(key, str)
        assert len(key) > 0

    def test_key_has_sufficient_entropy(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A 32-byte urlsafe token encodes to >= 32 characters."""
        monkeypatch.delenv("ART_API_KEY", raising=False)
        key = _generate_api_key()
        assert len(key) >= 32

    def test_keys_are_unique_across_calls(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Each invocation should produce a different random key."""
        # Without clearing the env var, all 20 calls would return the
        # same pinned value and this test would fail.
        monkeypatch.delenv("ART_API_KEY", raising=False)
        keys = {_generate_api_key() for _ in range(20)}
        assert len(keys) == 20

    def test_env_var_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ART_API_KEY environment variable should take precedence."""
        monkeypatch.setenv("ART_API_KEY", "my-custom-key-42")
        assert _generate_api_key() == "my-custom-key-42"

    def test_env_var_empty_falls_back_to_random(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An empty ART_API_KEY should be treated as unset."""
        monkeypatch.setenv("ART_API_KEY", "")
        key = _generate_api_key()
        assert key != ""
        assert key != "default"

    def test_env_var_unset_falls_back_to_random(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When ART_API_KEY is not in the environment, use random key."""
        monkeypatch.delenv("ART_API_KEY", raising=False)
        key = _generate_api_key()
        assert key != "default"
        assert len(key) >= 32


class TestGetOpenaiServerConfig:
    """Tests for ``get_openai_server_config`` API key behaviour.

    Tests that rely on the random-key path clear any ambient
    ``ART_API_KEY`` first, so outcomes do not depend on the shell
    environment of the machine running the suite.
    """

    def test_default_config_uses_random_key(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Without user-supplied config, the key must not be 'default'."""
        # A pinned (possibly short) env key would break the length
        # assertion below, so make sure the random path is taken.
        monkeypatch.delenv("ART_API_KEY", raising=False)
        config = get_openai_server_config(
            model_name="test-model",
            base_model="base-model",
            log_file="/tmp/test.log",
        )
        api_key = config["server_args"]["api_key"]
        assert api_key != "default"
        assert isinstance(api_key, str)
        assert len(api_key) >= 32

    def test_user_supplied_key_takes_precedence(self) -> None:
        """A user-provided api_key in server_args should override the default."""
        user_config = OpenAIServerConfig(
            server_args={"api_key": "user-provided-key-xyz"}  # type: ignore[typeddict-item]
        )
        config = get_openai_server_config(
            model_name="test-model",
            base_model="base-model",
            log_file="/tmp/test.log",
            config=user_config,
        )
        assert config["server_args"]["api_key"] == "user-provided-key-xyz"

    def test_env_var_key_used_when_no_user_config(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """ART_API_KEY env var should be picked up when no user key is given."""
        monkeypatch.setenv("ART_API_KEY", "env-key-123")
        config = get_openai_server_config(
            model_name="test-model",
            base_model="base-model",
            log_file="/tmp/test.log",
        )
        assert config["server_args"]["api_key"] == "env-key-123"

    def test_user_key_overrides_env_var(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """User-supplied key should win over ART_API_KEY env var."""
        monkeypatch.setenv("ART_API_KEY", "env-key-123")
        user_config = OpenAIServerConfig(
            server_args={"api_key": "user-wins"}  # type: ignore[typeddict-item]
        )
        config = get_openai_server_config(
            model_name="test-model",
            base_model="base-model",
            log_file="/tmp/test.log",
            config=user_config,
        )
        assert config["server_args"]["api_key"] == "user-wins"

    def test_each_config_call_gets_unique_key(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Successive calls should not share the same random key."""
        # With ART_API_KEY set, every call would return the same pinned
        # value and this test would fail, so clear it first.
        monkeypatch.delenv("ART_API_KEY", raising=False)
        keys = set()
        for _ in range(10):
            config = get_openai_server_config(
                model_name="test-model",
                base_model="base-model",
                log_file="/tmp/test.log",
            )
            keys.add(config["server_args"]["api_key"])
        assert len(keys) == 10