Skip to content

Commit 0e9b36a

Browse files
authored
feat: SDK code (#3)
* feat: Added code * feat: Added setup * feat: more sdk code * feat: poetry * feat: poetry setup * ci: actions * fix: removed unused example * feat: added boto3 * feat: added dependencies * test: added pytest dependency * fix: python version * fix: update python version in lock * fix: format * fix: examples * test: removed example script for studio and added integration test instead * test: bedrock integration test * test: moved examples * ci: fixed inv * fix: lint * feat: version in init * fix: long content * fix: poetry version * fix: added __all__ * fix: Added code to __all__ * fix: prompt * fix: test action * fix: Added shebang * fix: long line * fix: loaded env for tests * fix: Added env * test: only 3.10 * test: default region * test: Added 3.8 * fix: subscriptable type * test: sagemaker tests * fix: used _http methods * fix: default values * ci: removed -vv flag * fix: imports * test: Added conditional skip * fix: CR fixes * fix: boto3 to pyproject.toml * fix: all-extras arg * fix: lint in action * feat: via param * fix: added all extras * fix: Added static type checker * feat: Moved body creationto function * feat: switched most responses to use dataclasses_json * feat: Added base mixin * fix: CR * fix: test path * fix: CR * feat: Added bedrock session * feat: Added SageMakerSession * fix: init of bedrock client * feat: More robust imports * fix: error message * fix: removed kwargs from request body * fix: Removed log_level from env * fix: logger calls * fix: Removed logger from init * feat: Added setup logger * ci: Added integration tests only on push to main * fix: removed unused import
1 parent a2841de commit 0e9b36a

File tree

125 files changed

+4748
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+4748
-0
lines changed

.github/workflows/publish.yaml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# This workflow will upload a Python Package using Twine when a release is created
2+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3+
4+
name: Publish to PYPI
5+
6+
on:
7+
release:
8+
types: [published]
9+
10+
permissions:
11+
contents: read
12+
13+
jobs:
14+
deploy:
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: ["3.10"]
19+
20+
steps:
21+
- uses: actions/checkout@v3
22+
- name: Install Poetry
23+
run: |
24+
pipx install poetry
25+
- name: Set up Python
26+
uses: actions/setup-python@v4
27+
with:
28+
python-version: ${{ matrix.python-version }}
29+
cache: poetry
30+
cache-dependency-path: poetry.lock
31+
- name: Set Poetry environment
32+
run: |
33+
poetry env use ${{ matrix.python-version }}
34+
- name: Build package
35+
run: poetry build
36+
- name: Publish package to PYPI
37+
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
38+
with:
39+
user: __token__
40+
password: ${{ secrets.PYPI_API_TOKEN }}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
name: Semantic Release
2+
3+
on:
4+
workflow_dispatch:
5+
6+
jobs:
7+
release:
8+
runs-on: ubuntu-latest
9+
concurrency: release
10+
permissions:
11+
id-token: write
12+
contents: write
13+
14+
steps:
15+
- uses: actions/checkout@v3
16+
with:
17+
fetch-depth: 0
18+
persist-credentials: false
19+
20+
- name: Python Semantic Release
21+
uses: python-semantic-release/python-semantic-release@v8.3.0
22+
with:
23+
github_token: ${{ secrets.GH_PAT_SEM_REL_ASAFG }}

.github/workflows/test.yaml

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
name: Test
2+
3+
on: [push]
4+
5+
env:
6+
POETRY_VERSION: "1.7.1"
7+
POETRY_URL: https://install.python-poetry.org
8+
9+
jobs:
10+
lint:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.10"]
15+
16+
steps:
17+
- name: Checkout
18+
uses: actions/checkout@v3
19+
- name: Install Poetry
20+
run: |
21+
pipx install poetry
22+
- name: Set up Python
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: ${{ matrix.python-version }}
26+
cache: poetry
27+
cache-dependency-path: poetry.lock
28+
- name: Set Poetry environment
29+
run: |
30+
poetry env use ${{ matrix.python-version }}
31+
- name: Install dependencies
32+
run: |
33+
poetry install --no-root --only dev --all-extras
34+
- name: Lint Python (Black)
35+
run: |
36+
poetry run inv formatter
37+
- name: Lint Python (Ruff)
38+
run: |
39+
poetry run inv lint
40+
- name: Lint Python (isort)
41+
run: |
42+
poetry run inv isort
43+
unittests:
44+
runs-on: ubuntu-latest
45+
strategy:
46+
matrix:
47+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
48+
steps:
49+
- name: Checkout
50+
uses: actions/checkout@v3
51+
- name: Install Poetry
52+
run: |
53+
pipx install poetry
54+
- name: Set up Python
55+
uses: actions/setup-python@v4
56+
with:
57+
python-version: ${{ matrix.python-version }}
58+
cache: poetry
59+
cache-dependency-path: poetry.lock
60+
- name: Set Poetry environment
61+
run: |
62+
poetry env use ${{ matrix.python-version }}
63+
- name: Install dependencies
64+
run: |
65+
poetry install --all-extras
66+
- name: Run Tests
67+
env:
68+
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
69+
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
70+
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
71+
run: |
72+
poetry run pytest
73+
- name: Upload pytest test results
74+
uses: actions/upload-artifact@v3
75+
with:
76+
name: pytest-results-${{ matrix.python-version }}
77+
path: junit/test-results-${{ matrix.python-version }}.xml
78+
# Use always() to always run this step to publish test results when there are test failures
79+
if: ${{ always() }}
80+
81+
integration-tests:
82+
runs-on: ubuntu-latest
83+
84+
if: github.ref == 'refs/heads/main'
85+
86+
strategy:
87+
matrix:
88+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
89+
steps:
90+
- name: Checkout
91+
uses: actions/checkout@v3
92+
- name: Install Poetry
93+
run: |
94+
pipx install poetry
95+
- name: Set up Python
96+
uses: actions/setup-python@v4
97+
with:
98+
python-version: ${{ matrix.python-version }}
99+
cache: poetry
100+
cache-dependency-path: poetry.lock
101+
- name: Set Poetry environment
102+
run: |
103+
poetry env use ${{ matrix.python-version }}
104+
- name: Install dependencies
105+
run: |
106+
poetry install --all-extras
107+
- name: Run Integration Tests
108+
env:
109+
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
110+
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
111+
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
112+
run: |
113+
poetry run pytest tests/integration_tests/
114+
- name: Upload pytest integration tests results
115+
uses: actions/upload-artifact@v3
116+
with:
117+
name: pytest-results-${{ matrix.python-version }}
118+
path: junit/test-results-${{ matrix.python-version }}.xml
119+
# Use always() to always run this step to publish test results when there are test failures
120+
if: ${{ always() }}

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.10.6

ai21/__init__.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Any
2+
3+
from ai21.clients.studio.ai21_client import AI21Client
4+
from ai21.logger import setup_logger
5+
from ai21.resources.responses.answer_response import AnswerResponse
6+
from ai21.resources.responses.chat_response import ChatResponse
7+
from ai21.resources.responses.completion_response import CompletionsResponse
8+
from ai21.resources.responses.custom_model_response import CustomBaseModelResponse
9+
from ai21.resources.responses.dataset_response import DatasetResponse
10+
from ai21.resources.responses.embed_response import EmbedResponse
11+
from ai21.resources.responses.file_response import FileResponse
12+
from ai21.resources.responses.gec_response import GECResponse
13+
from ai21.resources.responses.improvement_response import ImprovementsResponse
14+
from ai21.resources.responses.library_answer_response import LibraryAnswerResponse
15+
from ai21.resources.responses.library_search_response import LibrarySearchResponse
16+
from ai21.resources.responses.paraphrase_response import ParaphraseResponse
17+
from ai21.resources.responses.segmentation_response import SegmentationResponse
18+
from ai21.resources.responses.summarize_by_segment_response import SummarizeBySegmentResponse
19+
from ai21.resources.responses.summarize_response import SummarizeResponse
20+
from ai21.services.sagemaker import SageMaker
21+
from ai21.version import VERSION
22+
23+
__version__ = VERSION
24+
setup_logger()
25+
26+
27+
def _import_bedrock_client():
28+
from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient
29+
30+
return AI21BedrockClient
31+
32+
33+
def _import_sagemaker_client():
34+
from ai21.clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient
35+
36+
return AI21SageMakerClient
37+
38+
39+
def _import_bedrock_model_id():
40+
from ai21.clients.bedrock.bedrock_model_id import BedrockModelID
41+
42+
return BedrockModelID
43+
44+
45+
def __getattr__(name: str) -> Any:
46+
try:
47+
if name == "AI21BedrockClient":
48+
return _import_bedrock_client()
49+
50+
if name == "AI21SageMakerClient":
51+
return _import_sagemaker_client()
52+
53+
if name == "BedrockModelID":
54+
return _import_bedrock_model_id()
55+
except ImportError as e:
56+
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e
57+
58+
59+
__all__ = [
60+
"AI21Client",
61+
"AI21BedrockClient",
62+
"AI21SageMakerClient",
63+
"BedrockModelID",
64+
"AnswerResponse",
65+
"ChatResponse",
66+
"CompletionsResponse",
67+
"CustomBaseModelResponse",
68+
"DatasetResponse",
69+
"EmbedResponse",
70+
"FileResponse",
71+
"GECResponse",
72+
"ImprovementsResponse",
73+
"LibraryAnswerResponse",
74+
"LibrarySearchResponse",
75+
"ParaphraseResponse",
76+
"SageMaker",
77+
"SegmentationResponse",
78+
"SummarizeBySegmentResponse",
79+
"SummarizeResponse",
80+
]

ai21/ai21_env_config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from __future__ import annotations
2+
import os
3+
from dataclasses import dataclass
4+
from typing import Optional
5+
6+
from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST
7+
8+
9+
@dataclass(frozen=True)
10+
class _AI21EnvConfig:
11+
api_key: Optional[str] = None
12+
api_url: Optional[str] = None
13+
api_version: str = DEFAULT_API_VERSION
14+
api_host: str = STUDIO_HOST
15+
organization: Optional[str] = None
16+
application: Optional[str] = None
17+
timeout_sec: Optional[int] = None
18+
num_retries: Optional[int] = None
19+
aws_region: Optional[str] = None
20+
log_level: Optional[str] = None
21+
22+
@classmethod
23+
def from_env(cls) -> _AI21EnvConfig:
24+
return cls(
25+
api_key=os.getenv("AI21_API_KEY"),
26+
api_url=os.getenv("AI21_API_URL"),
27+
api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION),
28+
api_host=os.getenv("AI21_API_HOST", STUDIO_HOST),
29+
organization=os.getenv("AI21_ORGANIZATION"),
30+
application=os.getenv("AI21_APPLICATION"),
31+
timeout_sec=os.getenv("AI21_TIMEOUT_SEC"),
32+
num_retries=os.getenv("AI21_NUM_RETRIES"),
33+
aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"),
34+
log_level=os.getenv("AI21_LOG_LEVEL", "info"),
35+
)
36+
37+
38+
AI21EnvConfig = _AI21EnvConfig.from_env()

ai21/ai21_studio_client.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from typing import Optional, Dict, Any
2+
3+
from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
4+
from ai21.errors import MissingApiKeyException
5+
from ai21.http_client import HttpClient
6+
from ai21.version import VERSION
7+
8+
9+
class AI21StudioClient:
10+
def __init__(
11+
self,
12+
*,
13+
api_key: Optional[str] = None,
14+
api_host: Optional[str] = None,
15+
api_version: Optional[str] = None,
16+
headers: Optional[Dict[str, Any]] = None,
17+
timeout_sec: Optional[int] = None,
18+
num_retries: Optional[int] = None,
19+
organization: Optional[str] = None,
20+
via: Optional[str] = None,
21+
env_config: _AI21EnvConfig = AI21EnvConfig,
22+
):
23+
self._env_config = env_config
24+
self._api_key = api_key or self._env_config.api_key
25+
26+
if self._api_key is None:
27+
raise MissingApiKeyException()
28+
29+
self._api_host = api_host or self._env_config.api_host
30+
self._api_version = api_version or self._env_config.api_version
31+
self._headers = headers
32+
self._timeout_sec = timeout_sec or self._env_config.timeout_sec
33+
self._num_retries = num_retries or self._env_config.num_retries
34+
self._organization = organization or self._env_config.organization
35+
self._application = self._env_config.application
36+
self._via = via
37+
38+
headers = self._build_headers(passed_headers=headers)
39+
40+
self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers)
41+
42+
def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
43+
headers = {
44+
"Content-Type": "application/json",
45+
"User-Agent": self._build_user_agent(),
46+
}
47+
48+
if self._api_key:
49+
headers["Authorization"] = f"Bearer {self._api_key}"
50+
51+
if passed_headers is not None:
52+
headers.update(passed_headers)
53+
54+
return headers
55+
56+
def _build_user_agent(self) -> str:
57+
user_agent = f"ai21 studio SDK {VERSION}"
58+
59+
if self._organization is not None:
60+
user_agent = f"{user_agent} organization: {self._organization}"
61+
62+
if self._application is not None:
63+
user_agent = f"{user_agent} application: {self._application}"
64+
65+
if self._via is not None:
66+
user_agent = f"{user_agent} via: {self._via}"
67+
68+
return user_agent
69+
70+
def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None):
71+
return self.http_client.execute_http_request(method=method, url=url, params=params, files=files)
72+
73+
def get_base_url(self) -> str:
74+
return f"{self._api_host}/studio/{self._api_version}"

ai21/clients/__init__.py

Whitespace-only changes.

ai21/clients/bedrock/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)