Skip to content

Commit 08a5b41

Browse files
authored
feat: Added launchpad client (#285)
* feat: Added launchpad client * test: Added imports to test
1 parent 4175595 commit 08a5b41

File tree

13 files changed

+293
-9
lines changed

13 files changed

+293
-9
lines changed

.git-hooks/check_api_key.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22

33
# Check for `api_key=` in staged changes
4-
if git diff --cached | grep -q "api_key="; then
4+
if git diff --cached | grep -q -E '\bapi_key=[^"]'; then
55
echo "❌ Commit blocked: Found 'api_key=' in staged changes."
66
exit 1 # Prevent commit
77
fi

ai21/__init__.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
from ai21.clients.azure.ai21_azure_client import AI21AzureClient, AsyncAI21AzureClient
55
from ai21.clients.studio.ai21_client import AI21Client
66
from ai21.clients.studio.async_ai21_client import AsyncAI21Client
7-
87
from ai21.errors import (
98
AI21APIError,
9+
AI21Error,
1010
APITimeoutError,
1111
MissingApiKeyError,
1212
ModelPackageDoesntExistError,
13-
AI21Error,
1413
TooManyRequestsError,
1514
)
1615
from ai21.logger import setup_logger
1716
from ai21.version import VERSION
1817

18+
1919
__version__ = VERSION
2020
setup_logger()
2121

@@ -44,6 +44,18 @@ def _import_vertex_client():
4444
return AI21VertexClient
4545

4646

47+
def _import_launchpad_client():
48+
from ai21.clients.launchpad.ai21_launchpad_client import AI21LaunchpadClient
49+
50+
return AI21LaunchpadClient
51+
52+
53+
def _import_async_launchpad_client():
54+
from ai21.clients.launchpad.ai21_launchpad_client import AsyncAI21LaunchpadClient
55+
56+
return AsyncAI21LaunchpadClient
57+
58+
4759
def _import_async_vertex_client():
4860
from ai21.clients.vertex.ai21_vertex_client import AsyncAI21VertexClient
4961

@@ -66,6 +78,13 @@ def __getattr__(name: str) -> Any:
6678

6779
if name == "AsyncAI21VertexClient":
6880
return _import_async_vertex_client()
81+
82+
if name == "AI21LaunchpadClient":
83+
return _import_launchpad_client()
84+
85+
if name == "AsyncAI21LaunchpadClient":
86+
return _import_async_launchpad_client()
87+
6988
except ImportError as e:
7089
raise ImportError('Please install "ai21[AWS]" for Bedrock, or "ai21[Vertex]" for Vertex') from e
7190

@@ -87,4 +106,6 @@ def __getattr__(name: str) -> Any:
87106
"AsyncAI21BedrockClient",
88107
"AI21VertexClient",
89108
"AsyncAI21VertexClient",
109+
"AI21LaunchpadClient",
110+
"AsyncAI21LaunchpadClient",
90111
]

ai21/clients/common/auth/__init__.py

Whitespace-only changes.

ai21/clients/launchpad/__init__.py

Whitespace-only changes.
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict, Optional
4+
5+
import httpx
6+
7+
from google.auth.credentials import Credentials as GCPCredentials
8+
9+
from ai21.clients.common.auth.gcp_authorization import GCPAuthorization
10+
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
11+
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
12+
from ai21.http_client.http_client import AI21HTTPClient
13+
from ai21.models.request_options import RequestOptions
14+
15+
16+
_DEFAULT_GCP_REGION = "us-central1"
17+
_LAUNCHPAD_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1"
18+
_LAUNCHPAD_PATH_FORMAT = "/projects/{project_id}/locations/{region}/endpoints/{endpoint_id}:{endpoint}"
19+
20+
21+
class BaseAI21LaunchpadClient:
22+
def __init__(
23+
self,
24+
region: Optional[str] = None,
25+
project_id: Optional[str] = None,
26+
endpoint_id: Optional[str] = None,
27+
access_token: Optional[str] = None,
28+
credentials: Optional[GCPCredentials] = None,
29+
):
30+
if access_token is not None and project_id is None:
31+
raise ValueError("Field project_id is required when setting access_token")
32+
self._region = region or _DEFAULT_GCP_REGION
33+
self._access_token = access_token
34+
self._project_id = project_id
35+
self._endpoint_id = endpoint_id
36+
self._credentials = credentials
37+
self._gcp_auth = GCPAuthorization()
38+
39+
def _get_base_url(self) -> str:
40+
return _LAUNCHPAD_BASE_URL_FORMAT.format(region=self._region)
41+
42+
def _get_access_token(self) -> str:
43+
if self._access_token is not None:
44+
return self._access_token
45+
46+
if self._credentials is None:
47+
self._credentials, self._project_id = self._gcp_auth.get_gcp_credentials(
48+
project_id=self._project_id,
49+
)
50+
51+
if self._credentials is None:
52+
raise ValueError("Could not get credentials for GCP project")
53+
54+
self._gcp_auth.refresh_auth(self._credentials)
55+
56+
if self._credentials.token is None:
57+
raise RuntimeError(f"Could not get access token for GCP project {self._project_id}")
58+
59+
return self._credentials.token
60+
61+
def _build_path(
62+
self,
63+
project_id: str,
64+
region: str,
65+
model: str,
66+
endpoint: str,
67+
) -> str:
68+
return _LAUNCHPAD_PATH_FORMAT.format(
69+
project_id=project_id,
70+
region=region,
71+
endpoint_id=self._endpoint_id,
72+
model=model,
73+
endpoint=endpoint,
74+
)
75+
76+
def _get_authorization_header(self) -> Dict[str, Any]:
77+
access_token = self._get_access_token()
78+
return {"Authorization": f"Bearer {access_token}"}
79+
80+
81+
class AI21LaunchpadClient(BaseAI21LaunchpadClient, AI21HTTPClient):
82+
def __init__(
83+
self,
84+
region: Optional[str] = None,
85+
project_id: Optional[str] = None,
86+
endpoint_id: Optional[str] = None,
87+
base_url: Optional[str] = None,
88+
access_token: Optional[str] = None,
89+
credentials: Optional[GCPCredentials] = None,
90+
headers: Dict[str, str] | None = None,
91+
timeout_sec: Optional[float] = None,
92+
num_retries: Optional[int] = None,
93+
http_client: Optional[httpx.Client] = None,
94+
):
95+
BaseAI21LaunchpadClient.__init__(
96+
self,
97+
region=region,
98+
project_id=project_id,
99+
endpoint_id=endpoint_id,
100+
access_token=access_token,
101+
credentials=credentials,
102+
)
103+
104+
if base_url is None:
105+
base_url = self._get_base_url()
106+
107+
AI21HTTPClient.__init__(
108+
self,
109+
base_url=base_url,
110+
timeout_sec=timeout_sec,
111+
num_retries=num_retries,
112+
headers=headers,
113+
client=http_client,
114+
requires_api_key=False,
115+
)
116+
117+
self.chat = StudioChat(self)
118+
# Override the chat.create method to match the completions endpoint,
119+
# so it wouldn't get to the old J2 completion endpoint
120+
self.chat.create = self.chat.completions.create
121+
122+
def _build_request(self, options: RequestOptions) -> httpx.Request:
123+
options = self._prepare_options(options)
124+
125+
return super()._build_request(options)
126+
127+
def _prepare_options(self, options: RequestOptions) -> RequestOptions:
128+
body = options.body
129+
130+
model = body.pop("model")
131+
stream = body.pop("stream", False)
132+
endpoint = "streamRawPredict" if stream else "rawPredict"
133+
headers = self._prepare_headers()
134+
path = self._build_path(
135+
project_id=self._project_id,
136+
region=self._region,
137+
model=model,
138+
endpoint=endpoint,
139+
)
140+
141+
return options.replace(
142+
body=body,
143+
path=path,
144+
headers=headers,
145+
)
146+
147+
def _prepare_headers(self) -> Dict[str, Any]:
148+
return self._get_authorization_header()
149+
150+
151+
class AsyncAI21LaunchpadClient(BaseAI21LaunchpadClient, AsyncAI21HTTPClient):
152+
def __init__(
153+
self,
154+
region: Optional[str] = None,
155+
project_id: Optional[str] = None,
156+
endpoint_id: Optional[str] = None,
157+
base_url: Optional[str] = None,
158+
access_token: Optional[str] = None,
159+
credentials: Optional[GCPCredentials] = None,
160+
headers: Dict[str, str] | None = None,
161+
timeout_sec: Optional[float] = None,
162+
num_retries: Optional[int] = None,
163+
http_client: Optional[httpx.AsyncClient] = None,
164+
):
165+
BaseAI21LaunchpadClient.__init__(
166+
self,
167+
region=region,
168+
project_id=project_id,
169+
endpoint_id=endpoint_id,
170+
access_token=access_token,
171+
credentials=credentials,
172+
)
173+
174+
if base_url is None:
175+
base_url = self._get_base_url()
176+
177+
AsyncAI21HTTPClient.__init__(
178+
self,
179+
base_url=base_url,
180+
timeout_sec=timeout_sec,
181+
num_retries=num_retries,
182+
headers=headers,
183+
client=http_client,
184+
requires_api_key=False,
185+
)
186+
187+
self.chat = AsyncStudioChat(self)
188+
# Override the chat.create method to match the completions endpoint,
189+
# so it wouldn't get to the old J2 completion endpoint
190+
self.chat.create = self.chat.completions.create
191+
192+
def _build_request(self, options: RequestOptions) -> httpx.Request:
193+
options = self._prepare_options(options)
194+
195+
return super()._build_request(options)
196+
197+
def _prepare_options(self, options: RequestOptions) -> RequestOptions:
198+
body = options.body
199+
200+
model = body.pop("model")
201+
stream = body.pop("stream", False)
202+
endpoint = "streamRawPredict" if stream else "rawPredict"
203+
headers = self._prepare_headers()
204+
path = self._build_path(
205+
project_id=self._project_id,
206+
region=self._region,
207+
model=model,
208+
endpoint=endpoint,
209+
)
210+
211+
return options.replace(
212+
body=body,
213+
path=path,
214+
headers=headers,
215+
)
216+
217+
def _prepare_headers(self) -> Dict[str, Any]:
218+
return self._get_authorization_header()

ai21/clients/vertex/ai21_vertex_client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
from __future__ import annotations
22

3-
from typing import Optional, Dict, Any
3+
from typing import Any, Dict, Optional
44

55
import httpx
6+
67
from google.auth.credentials import Credentials as GCPCredentials
78

8-
from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
9-
from ai21.clients.vertex.gcp_authorization import GCPAuthorization
9+
from ai21.clients.common.auth.gcp_authorization import GCPAuthorization
10+
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
1011
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
1112
from ai21.http_client.http_client import AI21HTTPClient
1213
from ai21.models.request_options import RequestOptions
1314

15+
1416
_DEFAULT_GCP_REGION = "us-central1"
1517
_VERTEX_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1"
1618
_VERTEX_PATH_FORMAT = "/projects/{project_id}/locations/{region}/publishers/ai21/models/{model}:{endpoint}"

examples/launchpad/__init__.py

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import asyncio
2+
3+
from ai21 import AsyncAI21LaunchpadClient
4+
from ai21.models.chat import ChatMessage
5+
6+
7+
client = AsyncAI21LaunchpadClient(endpoint_id="<your_endpoint_id>")
8+
9+
10+
async def main():
11+
messages = ChatMessage(content="What is the meaning of life?", role="user")
12+
13+
completion = await client.chat.completions.create(
14+
model="jamba-1.6-large",
15+
messages=[messages],
16+
)
17+
18+
print(completion)
19+
20+
21+
asyncio.run(main())
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from ai21 import AI21LaunchpadClient
2+
from ai21.models.chat import ChatMessage
3+
4+
5+
client = AI21LaunchpadClient(endpoint_id="<your_endpoint_id>")
6+
7+
messages = ChatMessage(content="What is the meaning of life?", role="user")
8+
9+
completion = client.chat.completions.create(
10+
model="jamba-1.6-large",
11+
messages=[messages],
12+
stream=True,
13+
)
14+
15+
16+
print(completion)

0 commit comments

Comments
 (0)