Skip to content

Commit a6ef5cf

Browse files
feat: add Vertex AI support for all providers (#135)
* feat: add Vertex AI support for all providers Route requests through Vertex AI when GoogleADC auth is provided. Supports Google, Anthropic, Mistral, and DeepSeek providers across text, images, and videos modalities. Includes Veo polling fix (fetchPredictOperation), error handler hardening, Gemini image role fix, and DeepSeek usage parser fix. WIP: Veo Vertex inline video (bytesBase64Encoded) parsing not yet handled — needs base64 decoding or storageUri in request. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(ci): install gcp extra so Vertex routing tests can import google-auth The google-auth package is optional under [gcp], but unit tests in test_vertex_routing.py and Vertex integration tests need it importable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(veo): handle Vertex inline bytesBase64Encoded and videoGcsUri key mismatch Vertex Veo responses use videoGcsUri (not uri/gcsUri) and can return inline base64 instead of a GCS URL. Normalize the key and decode inline responses directly into VideoArtifact. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(embeddings): adapt request/response format for Vertex :predict endpoint Vertex embeddings uses :predict with instances format, not :embedContent. Build correct request body in _init_request when auth is GoogleADC, and parse predictions response format in _parse_content. Add integration test. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * chore: fix trailing newlines in workflow files Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(vertex): move embeddings auth to provider mixin, rename VertexEndpoint, update templates - Move isinstance(self.auth, GoogleADC) check from modality _init_request() to provider mixin _make_request() for embeddings, keeping auth logic in provider layer - Fix misplaced class docstring in GoogleEmbeddingsClient mixin - Rename VertexEndpoint to VertexGenerateContentEndpoint for consistency with VertexImagenEndpoint, VertexEmbeddingsEndpoint, etc. - Add Vertex AI routing patterns (commented) to provider templates Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(vertex): centralize URL building in GoogleADC.build_url() Move duplicated project_id validation, base URL resolution, and endpoint formatting from 7 provider _build_url() methods into GoogleADC.build_url(). Also remove manual base64.b64decode from video client (Artifact validator handles it). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(images): let Artifact validator handle base64 decoding in Gemini images Same pattern as the video client fix - pass base64 string directly to ImageArtifact(data=...) instead of manual base64.b64decode(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(images): let Artifact validator handle base64 decoding in Imagen Same pattern as Gemini images and Veo video fixes - pass base64 string directly to ImageArtifact(data=...) instead of manual base64.b64decode(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent af37ace commit a6ef5cf

File tree

32 files changed

+1093
-61
lines changed

32 files changed

+1093
-61
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ jobs:
8585
- uses: ./.github/actions/setup-python-uv
8686
with:
8787
python-version: ${{ matrix.python-version }}
88+
- run: uv sync --extra gcp
8889
- run: uv run pytest tests/unit_tests -v --cov=celeste --cov-report=term-missing --cov-report=xml --cov-report=html --cov-fail-under=80
8990
- uses: codecov/codecov-action@v4
9091
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.12'

.github/workflows/claude-code-review.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,3 @@ jobs:
4141
prompt: '/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}'
4242
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
4343
# or https://code.claude.com/docs/en/cli-reference for available options
44-

.github/workflows/claude.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,3 @@ jobs:
4747
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
4848
# or https://code.claude.com/docs/en/cli-reference for available options
4949
# claude_args: '--allowed-tools Bash(gh pr:*)'
50-

.github/workflows/publish.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ jobs:
5858
workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
5959
service_account: ${{ secrets.GCP_SERVICE_ACCOUNT }}
6060
- uses: ./.github/actions/setup-python-uv
61+
- run: uv sync --extra gcp
6162
- name: Run ${{ matrix.package }} integration tests
6263
env:
6364
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

src/celeste/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,8 @@ def _handle_error_response(self, response: httpx.Response) -> None:
296296
"""Handle error responses from provider APIs."""
297297
if not response.is_success:
298298
try:
299-
error_data = response.json()
300-
error_msg = error_data.get("error", {}).get("message", response.text)
301-
except JSONDecodeError:
299+
error_msg = response.json()["error"]["message"]
300+
except (JSONDecodeError, KeyError, TypeError, IndexError):
302301
error_msg = response.text or f"HTTP {response.status_code}"
303302

304303
raise httpx.HTTPStatusError(

src/celeste/modalities/images/providers/google/gemini.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
8282
parts.append({"text": inputs.prompt})
8383

8484
return {
85-
"contents": [{"parts": parts}],
85+
"contents": [{"role": "user", "parts": parts}],
8686
"generationConfig": {
8787
"responseModalities": ["TEXT", "IMAGE"],
8888
"imageConfig": {},
@@ -113,8 +113,7 @@ def _parse_content(
113113
if not base64_data:
114114
continue
115115
mime_type = ImageMimeType(inline_data.get("mimeType", "image/png"))
116-
image_bytes = base64.b64decode(base64_data)
117-
artifacts.append(ImageArtifact(data=image_bytes, mime_type=mime_type))
116+
artifacts.append(ImageArtifact(data=base64_data, mime_type=mime_type))
118117

119118
if not artifacts:
120119
return ImageArtifact()

src/celeste/modalities/images/providers/google/imagen.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Imagen client for Google images modality."""
22

3-
import base64
43
from typing import Any, Unpack
54

65
from celeste.artifacts import ImageArtifact
@@ -61,8 +60,7 @@ def _parse_content(
6160
if not base64_data:
6261
continue
6362
mime_type = ImageMimeType(prediction.get("mimeType", "image/png"))
64-
image_bytes = base64.b64decode(base64_data)
65-
images.append(ImageArtifact(data=image_bytes, mime_type=mime_type))
63+
images.append(ImageArtifact(data=base64_data, mime_type=mime_type))
6664

6765
num_images_requested = parameters.get("num_images")
6866
if num_images_requested == 1:

src/celeste/modalities/videos/providers/google/client.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Unpack
44

55
from celeste.artifacts import VideoArtifact
6+
from celeste.mime_types import VideoMimeType
67
from celeste.parameters import ParameterMapper
78
from celeste.providers.google.veo import config
89
from celeste.providers.google.veo.client import GoogleVeoClient as GoogleVeoMixin
@@ -51,6 +52,12 @@ def _parse_content(
5152
) -> VideoArtifact:
5253
"""Parse content from response."""
5354
video_data = super()._parse_content(response_data)
55+
# Handle inline base64 response (Vertex can return bytesBase64Encoded)
56+
if "bytesBase64Encoded" in video_data:
57+
mime_type = video_data.get("mimeType", VideoMimeType.MP4)
58+
return VideoArtifact(
59+
data=video_data["bytesBase64Encoded"], mime_type=mime_type
60+
)
5461
return VideoArtifact(url=video_data.get("uri"))
5562

5663
def _parse_finish_reason(self, response_data: dict[str, Any]) -> VideoFinishReason:
@@ -62,13 +69,16 @@ async def download_content(self, artifact: VideoArtifact) -> VideoArtifact:
6269
"""Download video content from GCS URL.
6370
6471
Args:
65-
artifact: VideoArtifact with URL to download.
72+
artifact: VideoArtifact with URL or inline data to download.
6673
6774
Returns:
6875
VideoArtifact with downloaded bytes data.
6976
"""
77+
if artifact.data is not None:
78+
return artifact
79+
7080
if artifact.url is None:
71-
msg = "Artifact has no URL to download"
81+
msg = "Artifact has no URL or data to download"
7282
raise ValueError(msg)
7383

7484
video_bytes = await super().download_content(artifact.url)

src/celeste/providers/anthropic/messages/client.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from celeste.core import UsageField
88
from celeste.io import FinishReason
99
from celeste.mime_types import ApplicationMimeType
10+
from celeste.providers.google.auth import GoogleADC
1011

1112
from . import config
1213

@@ -22,6 +23,10 @@ class AnthropicMessagesClient(APIMixin):
2223
- _parse_finish_reason() - Extract finish reason from response
2324
- _build_metadata() - Filter content fields
2425
26+
Auth-based endpoint selection:
27+
- GoogleADC auth -> Vertex AI endpoints (Claude on Google Cloud)
28+
- API key auth -> Anthropic API endpoints
29+
2530
Usage:
2631
class AnthropicTextGenerationClient(AnthropicMessagesClient, TextGenerationClient):
2732
def _parse_content(self, response_data, **parameters):
@@ -32,6 +37,23 @@ def _parse_content(self, response_data, **parameters):
3237
return ""
3338
"""
3439

40+
def _get_vertex_endpoint(
41+
self, anthropic_endpoint: str, streaming: bool = False
42+
) -> str:
43+
"""Map Anthropic endpoint to Vertex AI endpoint."""
44+
if streaming:
45+
return config.VertexAnthropicEndpoint.STREAM_MESSAGE
46+
return config.VertexAnthropicEndpoint.CREATE_MESSAGE
47+
48+
def _build_url(self, endpoint: str, streaming: bool = False) -> str:
49+
"""Build full URL based on auth type."""
50+
if isinstance(self.auth, GoogleADC):
51+
return self.auth.build_url(
52+
self._get_vertex_endpoint(endpoint, streaming=streaming),
53+
model_id=self.model.id,
54+
)
55+
return f"{config.BASE_URL}{endpoint}"
56+
3557
def _build_headers(self, request_body: dict[str, Any]) -> dict[str, str]:
3658
"""Build headers with beta features extracted from request."""
3759
beta_features: list[str] = request_body.pop("_beta_features", [])
@@ -85,7 +107,7 @@ async def _make_request(
85107
endpoint = config.AnthropicMessagesEndpoint.CREATE_MESSAGE
86108

87109
response = await self.http_client.post(
88-
f"{config.BASE_URL}{endpoint}",
110+
url=self._build_url(endpoint, streaming=False),
89111
headers=headers,
90112
json_body=request_body,
91113
)
@@ -111,7 +133,7 @@ def _make_stream_request(
111133
endpoint = config.AnthropicMessagesEndpoint.CREATE_MESSAGE
112134

113135
return self.http_client.stream_post(
114-
f"{config.BASE_URL}{endpoint}",
136+
url=self._build_url(endpoint, streaming=True),
115137
headers=headers,
116138
json_body=request_body,
117139
)

src/celeste/providers/anthropic/messages/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ class AnthropicMessagesEndpoint(StrEnum):
1212
GET_MODEL = "/v1/models/{model_id}"
1313

1414

15+
class VertexAnthropicEndpoint(StrEnum):
16+
"""Endpoints for Anthropic on Vertex AI."""
17+
18+
CREATE_MESSAGE = "/v1/projects/{project_id}/locations/{location}/publishers/anthropic/models/{model_id}:rawPredict"
19+
STREAM_MESSAGE = "/v1/projects/{project_id}/locations/{location}/publishers/anthropic/models/{model_id}:streamRawPredict"
20+
21+
1522
BASE_URL = "https://api.anthropic.com"
1623

1724
# Required

0 commit comments

Comments
 (0)