Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ jobs:
- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ inputs.python-version || '3.12' }}
- run: uv run ruff format --check src/celeste tests/
- run: |
if [ -d "packages" ]; then
uv run ruff format --check src/celeste tests/ packages/
else
uv run ruff format --check src/celeste tests/
fi

lint:
runs-on: ubuntu-latest
Expand All @@ -41,7 +46,12 @@ jobs:
- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ inputs.python-version || '3.12' }}
- run: uv run ruff check --output-format=github src/celeste tests/
- run: |
if [ -d "packages" ]; then
uv run ruff check --output-format=github src/celeste tests/ packages/
else
uv run ruff check --output-format=github src/celeste tests/
fi

type-check:
runs-on: ubuntu-latest
Expand All @@ -52,7 +62,12 @@ jobs:
- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ inputs.python-version || '3.12' }}
- run: uv run mypy -p celeste && uv run mypy tests/
- run: |
if [ -d "packages" ]; then
uv run mypy -p celeste && uv run mypy tests/ && uv run mypy packages/
else
uv run mypy -p celeste && uv run mypy tests/
fi

security:
runs-on: ubuntu-latest
Expand All @@ -63,7 +78,12 @@ jobs:
- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ inputs.python-version || '3.12' }}
- run: uv run bandit -c pyproject.toml -r src/ -f screen
- run: |
if [ -d "packages" ]; then
uv run bandit -c pyproject.toml -r src/ packages/ -f screen
else
uv run bandit -c pyproject.toml -r src/ -f screen
fi

test:
if: ${{ !inputs.skip-tests }}
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/claude-code-review.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,3 @@ jobs:
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://docs.claude.com/en/docs/claude-code/cli-reference for available options
claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'

1 change: 0 additions & 1 deletion .github/workflows/claude.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,3 @@ jobs:
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://docs.claude.com/en/docs/claude-code/cli-reference for available options
# claude_args: '--allowed-tools Bash(gh pr:*)'

10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,27 @@ sync:

# Linting
lint:
uv run ruff check src/celeste tests/
uv run ruff check src/celeste tests/ packages/

# Linting with auto-fix
lint-fix:
uv run ruff check --fix src/celeste tests/
uv run ruff check --fix src/celeste tests/ packages/

# Formatting
format:
uv run ruff format src/celeste tests/
uv run ruff format src/celeste tests/ packages/

# Type checking (fail fast on any error)
typecheck:
@uv run mypy -p celeste && uv run mypy tests/
@uv run mypy -p celeste && uv run mypy tests/ && uv run mypy packages/

# Testing
test:
uv run pytest tests/ --cov=celeste --cov-report=term-missing --cov-fail-under=90

# Security scanning (config reads from pyproject.toml)
security:
uv run bandit -c pyproject.toml -r src/ -f screen
uv run bandit -c pyproject.toml -r src/ packages/ -f screen

# Full CI/CD pipeline - what GitHub Actions will run
ci:
Expand Down
15 changes: 10 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ dev = [
"pre-commit>=3.5.0",
]

[project.entry-points."celeste.packages"]
# Example entry points for packages to register their models and clients:
# text_generation = "text_generation:register_models"
# image_generation = "image_generation:register_models"

[tool.uv.workspace]
members = ["packages/*"]

Expand Down Expand Up @@ -133,6 +128,7 @@ strict_equality = true
module = "tests.*"
disallow_untyped_defs = false # Relax for tests
disallow_incomplete_defs = false
disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"]

[[tool.mypy.overrides]]
module = "httpx"
Expand All @@ -142,6 +138,15 @@ ignore_missing_imports = true
module = "httpx_sse"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"celeste_text_generation.*",
"celeste_text_generation.client",
"celeste_text_generation.streaming",
"celeste_text_generation.providers.*",
]
disable_error_code = ["override", "return-value", "arg-type", "call-arg", "assignment", "no-any-return"]

[tool.bandit]
exclude_dirs = [".venv", "__pycache__"]
skips = ["B101"] # Skip B101 (assert_used) since we use pytest
Expand Down
54 changes: 40 additions & 14 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from celeste.streaming import Stream


class Client[In: Input, Out: Output](ABC, BaseModel):
class Client[In: Input, Out: Output, Params: Parameters](ABC, BaseModel):
"""Base class for all capability-specific clients."""

model_config = ConfigDict(from_attributes=True)
Expand All @@ -38,7 +38,11 @@ def http_client(self) -> HTTPClient:
"""Shared HTTP client with connection pooling for this provider."""
return get_http_client(self.provider, self.capability)

async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401
async def generate(
self,
*args: Any, # noqa: ANN401
**parameters: Unpack[Params], # type: ignore[misc]
) -> Out:
"""Generate content - signature varies by capability.

Args:
Expand All @@ -59,7 +63,11 @@ async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out:
metadata=self._build_metadata(response_data),
)

def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401
def stream(
self,
*args: Any, # noqa: ANN401
**parameters: Unpack[Params], # type: ignore[misc]
) -> Stream[Out, Params]:
"""Stream content - signature varies by capability.

Args:
Expand All @@ -79,7 +87,7 @@ def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]:
inputs = self._create_inputs(*args, **parameters)
request_body = self._build_request(inputs, **parameters)
sse_iterator = self._make_stream_request(request_body, **parameters)
return self._stream_class()( # type: ignore[call-arg]
return self._stream_class()(
sse_iterator,
transform_output=self._transform_output,
**parameters,
Expand All @@ -103,13 +111,19 @@ def _parse_usage(self, response_data: dict[str, Any]) -> Usage:

@abstractmethod
def _parse_content(
self, response_data: dict[str, Any], **parameters: Unpack[Parameters]
self,
response_data: dict[str, Any],
**parameters: Unpack[Params], # type: ignore[misc]
) -> object:
"""Parse content from provider response."""
...

@abstractmethod
def _create_inputs(self, *args: Any, **parameters: Unpack[Parameters]) -> In: # noqa: ANN401
def _create_inputs(
self,
*args: Any, # noqa: ANN401
**parameters: Unpack[Params], # type: ignore[misc]
) -> In:
"""Map positional arguments to Input type."""
...

Expand All @@ -121,19 +135,23 @@ def _output_class(cls) -> type[Out]:

@abstractmethod
async def _make_request(
self, request_body: dict[str, Any], **parameters: Unpack[Parameters]
self,
request_body: dict[str, Any],
**parameters: Unpack[Params], # type: ignore[misc]
) -> httpx.Response:
"""Make HTTP request(s) and return response object."""
...

@abstractmethod
def _stream_class(self) -> type[Stream[Out]]:
def _stream_class(self) -> type[Stream[Out, Params]]:
"""Return the Stream class for this client."""
...

@abstractmethod
def _make_stream_request(
self, request_body: dict[str, Any], **parameters: Unpack[Parameters]
self,
request_body: dict[str, Any],
**parameters: Unpack[Params], # type: ignore[misc]
) -> AsyncIterator[dict[str, Any]]:
"""Make HTTP streaming request and return async iterator of events."""
...
Expand Down Expand Up @@ -161,7 +179,9 @@ def _handle_error_response(self, response: httpx.Response) -> None:
)

def _transform_output(
self, content: object, **parameters: Unpack[Parameters]
self,
content: object,
**parameters: Unpack[Params], # type: ignore[misc]
) -> object:
"""Transform content using parameter mapper output transformations."""
for mapper in self.parameter_mappers():
Expand All @@ -171,7 +191,9 @@ def _transform_output(
return content

def _build_request(
self, inputs: In, **parameters: Unpack[Parameters]
self,
inputs: In,
**parameters: Unpack[Params], # type: ignore[misc]
) -> dict[str, Any]:
"""Build complete request by combining base request with parameters."""
request = self._init_request(inputs)
Expand All @@ -183,11 +205,13 @@ def _build_request(
return request


_clients: dict[tuple[Capability, Provider], type[Client]] = {}
_clients: dict[tuple[Capability, Provider], type[Client[Any, Any, Any]]] = {}


def register_client(
capability: Capability, provider: Provider, client_class: type[Client]
capability: Capability,
provider: Provider,
client_class: type[Client[Any, Any, Any]],
) -> None:
"""Register a provider-specific client class for a capability.

Expand All @@ -199,7 +223,9 @@ def register_client(
_clients[(capability, provider)] = client_class


def get_client_class(capability: Capability, provider: Provider) -> type[Client]:
def get_client_class(
capability: Capability, provider: Provider
) -> type[Client[Any, Any, Any]]:
"""Get the registered client class for a capability and provider.

Args:
Expand Down
13 changes: 12 additions & 1 deletion src/celeste/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,33 @@ class Range(Constraint):
"""Range constraint - value must be within min/max bounds.

If step is provided, value must be at min + (n * step) for some integer n.
If special_values is provided, those values bypass min/max validation.
"""

min: float | int
max: float | int
step: float | None = None
special_values: list[float | int] | None = None

def __call__(self, value: float | int) -> float | int:
"""Validate value is within range and matches step increment."""
if not isinstance(value, (int, float)):
msg = f"Must be numeric, got {type(value).__name__}"
raise TypeError(msg)

# Check if value is a special value that bypasses range check
if self.special_values is not None and value in self.special_values:
return value

# Validate range
if not self.min <= value <= self.max:
msg = f"Must be between {self.min} and {self.max}, got {value}"
special_msg = (
f" or one of {self.special_values}" if self.special_values else ""
)
msg = f"Must be between {self.min} and {self.max}{special_msg}, got {value}"
raise ValueError(msg)

# Validate step if provided
if self.step is not None:
remainder = (value - self.min) % self.step
# Use epsilon for floating-point comparison tolerance
Expand Down
2 changes: 1 addition & 1 deletion src/celeste/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def map(self, request: dict[str, Any], value: Any, model: Model) -> dict[str, An
"""
...

def parse_output(self, content: object, value: object | None) -> object:
def parse_output(self, content: Any, value: object | None) -> object: # noqa: ANN401
"""Optionally transform parsed content based on parameter value (default: return unchanged)."""
return content

Expand Down
11 changes: 7 additions & 4 deletions src/celeste/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,34 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from types import TracebackType
from typing import Any, Self
from typing import Any, Self, Unpack

from celeste.io import Chunk, Output
from celeste.parameters import Parameters


class Stream[Out: Output](ABC):
class Stream[Out: Output, Params: Parameters](ABC):
"""Async iterator wrapper providing final Output access after stream exhaustion."""

def __init__(
self,
sse_iterator: AsyncIterator[dict[str, Any]],
**parameters: Unpack[Params], # type: ignore[misc]
) -> None:
"""Initialize stream with SSE iterator."""
self._sse_iterator = sse_iterator
self._chunks: list[Chunk] = []
self._closed = False
self._output: Out | None = None
self._parameters = parameters

@abstractmethod
def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
"""Parse SSE event into Chunk (returns None to filter lifecycle events)."""
...

@abstractmethod
def _parse_output(self, chunks: list[Chunk]) -> Out:
def _parse_output(self, chunks: list[Chunk], **parameters: Unpack[Params]) -> Out: # type: ignore[misc]
"""Parse final Output from accumulated chunks."""
...

Expand Down Expand Up @@ -67,7 +70,7 @@ async def __anext__(self) -> Chunk:
msg = "Stream completed but no chunks were produced"
raise RuntimeError(msg)

self._output = self._parse_output(self._chunks)
self._output = self._parse_output(self._chunks, **self._parameters)
except Exception:
await self.aclose()
raise
Expand Down
Loading
Loading