Skip to content

Commit 2c5107b

Browse files
committed
feat: support NOT_GIVEN type
1 parent 8965828 commit 2c5107b

File tree

5 files changed

+49
-13
lines changed

5 files changed

+49
-13
lines changed

ai21/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ai21.logger import setup_logger
1414
from ai21.services.sagemaker import SageMaker
1515
from ai21.version import VERSION
16+
from ai21.types import NOT_GIVEN
1617

1718
__version__ = VERSION
1819
setup_logger()
@@ -63,4 +64,5 @@ def __getattr__(name: str) -> Any:
6364
"AI21SageMakerClient",
6465
"BedrockModelID",
6566
"SageMaker",
67+
"NOT_GIVEN",
6668
]

ai21/clients/studio/resources/studio_completion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from __future__ import annotations
2+
13
from typing import Optional, List
24

35
from ai21.clients.common.completion_base import Completion
46
from ai21.clients.studio.resources.studio_resource import StudioResource
57
from ai21.models import Penalty, CompletionsResponse
8+
from ai21.types import NOT_GIVEN
69

710

811
class StudioCompletion(StudioResource, Completion):
@@ -11,7 +14,7 @@ def create(
1114
model: str,
1215
prompt: str,
1316
*,
14-
max_tokens: Optional[int] = None,
17+
max_tokens: int | NOT_GIVEN = None,
1518
num_results: Optional[int] = 1,
1619
min_tokens: Optional[int] = 0,
1720
temperature: Optional[float] = 0.7,

ai21/types.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing_extensions import Literal
2+
3+
4+
# Sentinel class used until PEP 0661 is accepted
5+
class NotGiven:
6+
"""
7+
A sentinel singleton class used to distinguish omitted keyword arguments
8+
from those passed in with the value None (which may have different behavior).
9+
10+
For example:
11+
12+
```py
13+
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
14+
...
15+
16+
17+
get(timeout=1) # 1s timeout
18+
get(timeout=None) # No timeout
19+
get() # Default timeout behavior, which may not be statically known at the method definition.
20+
```
21+
"""
22+
23+
def __bool__(self) -> Literal[False]:
24+
return False
25+
26+
def __repr__(self) -> str:
27+
return "NOT_GIVEN"
28+
29+
30+
NOT_GIVEN = NotGiven()

poetry.lock

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ requests = "^2.31.0"
6060
ai21-tokenizer = "^0.3.9"
6161
boto3 = { version = "^1.28.82", optional = true }
6262
dataclasses-json = "^0.6.3"
63+
typing-extensions = "^4.9.0"
6364

6465

6566
[tool.poetry.group.dev.dependencies]

0 commit comments

Comments
 (0)