Skip to content

Commit 001acc5

Browse files
committed
feat: support NOT_GIVEN type
1 parent 8965828 commit 001acc5

File tree

16 files changed

+209
-148
lines changed

16 files changed

+209
-148
lines changed

ai21/clients/bedrock/resources/bedrock_completion.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import Optional, List
1+
from typing import List
22

33
from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource
44
from ai21.models import Penalty, CompletionsResponse
5+
from ai21.types import NotGiven, NOT_GIVEN
6+
from ai21.utils.typing import remove_not_given
57

68

79
class BedrockCompletion(BedrockResource):
@@ -10,37 +12,33 @@ def create(
1012
model_id: str,
1113
prompt: str,
1214
*,
13-
max_tokens: Optional[int] = None,
14-
num_results: Optional[int] = 1,
15-
min_tokens: Optional[int] = 0,
16-
temperature: Optional[float] = 0.7,
17-
top_p: Optional[int] = 1,
18-
top_k_return: Optional[int] = 0,
19-
stop_sequences: Optional[List[str]] = None,
20-
frequency_penalty: Optional[Penalty] = None,
21-
presence_penalty: Optional[Penalty] = None,
22-
count_penalty: Optional[Penalty] = None,
15+
max_tokens: int | NotGiven = NOT_GIVEN,
16+
num_results: int | NotGiven = NOT_GIVEN,
17+
min_tokens: int | NotGiven = NOT_GIVEN,
18+
temperature: float | NotGiven = NOT_GIVEN,
19+
top_p: float | NotGiven = NOT_GIVEN,
20+
top_k_return: int | NotGiven = NOT_GIVEN,
21+
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
22+
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
23+
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
24+
count_penalty: Penalty | NotGiven = NOT_GIVEN,
2325
**kwargs,
2426
) -> CompletionsResponse:
25-
body = {
26-
"prompt": prompt,
27-
"maxTokens": max_tokens,
28-
"numResults": num_results,
29-
"minTokens": min_tokens,
30-
"temperature": temperature,
31-
"topP": top_p,
32-
"topKReturn": top_k_return,
33-
"stopSequences": stop_sequences or [],
34-
}
35-
36-
if frequency_penalty is not None:
37-
body["frequencyPenalty"] = frequency_penalty.to_dict()
38-
39-
if presence_penalty is not None:
40-
body["presencePenalty"] = presence_penalty.to_dict()
41-
42-
if count_penalty is not None:
43-
body["countPenalty"] = count_penalty.to_dict()
27+
body = remove_not_given(
28+
{
29+
"prompt": prompt,
30+
"maxTokens": max_tokens,
31+
"numResults": num_results,
32+
"minTokens": min_tokens,
33+
"temperature": temperature,
34+
"topP": top_p,
35+
"topKReturn": top_k_return,
36+
"stopSequences": stop_sequences or [],
37+
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
38+
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
39+
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
40+
}
41+
)
4442

4543
raw_response = self._invoke(model_id=model_id, body=body)
4644

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from typing import Optional, List, Dict, Any
4+
from typing import List, Dict, Any
35

46
from ai21.models import Penalty, CompletionsResponse
7+
from ai21.types import NOT_GIVEN, NotGiven
8+
from ai21.utils.typing import remove_not_given
59

610

711
class Completion(ABC):
@@ -13,18 +17,18 @@ def create(
1317
model: str,
1418
prompt: str,
1519
*,
16-
max_tokens: int = 64,
17-
num_results: int = 1,
18-
min_tokens=0,
19-
temperature=0.7,
20-
top_p=1,
21-
top_k_return=0,
22-
custom_model: Optional[str] = None,
23-
stop_sequences: Optional[List[str]] = (),
24-
frequency_penalty: Optional[Penalty] = None,
25-
presence_penalty: Optional[Penalty] = None,
26-
count_penalty: Optional[Penalty] = None,
27-
epoch: Optional[int] = None,
20+
max_tokens: int | NotGiven = NOT_GIVEN,
21+
num_results: int | NotGiven = NOT_GIVEN,
22+
min_tokens: int | NotGiven = NOT_GIVEN,
23+
temperature: float | NOT_GIVEN = NOT_GIVEN,
24+
top_p: float | NotGiven = NOT_GIVEN,
25+
top_k_return: int | NotGiven = NOT_GIVEN,
26+
custom_model: str | NotGiven = NOT_GIVEN,
27+
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
28+
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
29+
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
30+
count_penalty: Penalty | NotGiven = NOT_GIVEN,
31+
epoch: int | NotGiven = NOT_GIVEN,
2832
**kwargs,
2933
) -> CompletionsResponse:
3034
"""
@@ -54,32 +58,34 @@ def _create_body(
5458
self,
5559
model: str,
5660
prompt: str,
57-
max_tokens: Optional[int],
58-
num_results: Optional[int],
59-
min_tokens: Optional[int],
60-
temperature: Optional[float],
61-
top_p: Optional[int],
62-
top_k_return: Optional[int],
63-
custom_model: Optional[str],
64-
stop_sequences: Optional[List[str]],
65-
frequency_penalty: Optional[Penalty],
66-
presence_penalty: Optional[Penalty],
67-
count_penalty: Optional[Penalty],
68-
epoch: Optional[int],
61+
max_tokens: int | NotGiven,
62+
num_results: int | NotGiven,
63+
min_tokens: int | NotGiven,
64+
temperature: float | NotGiven,
65+
top_p: float | NotGiven,
66+
top_k_return: int | NotGiven,
67+
custom_model: str | NotGiven,
68+
stop_sequences: List[str] | NotGiven,
69+
frequency_penalty: Penalty | NotGiven,
70+
presence_penalty: Penalty | NotGiven,
71+
count_penalty: Penalty | NotGiven,
72+
epoch: int | NotGiven,
6973
):
70-
return {
71-
"model": model,
72-
"customModel": custom_model,
73-
"prompt": prompt,
74-
"maxTokens": max_tokens,
75-
"numResults": num_results,
76-
"minTokens": min_tokens,
77-
"temperature": temperature,
78-
"topP": top_p,
79-
"topKReturn": top_k_return,
80-
"stopSequences": stop_sequences or [],
81-
"frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(),
82-
"presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(),
83-
"countPenalty": None if count_penalty is None else count_penalty.to_dict(),
84-
"epoch": epoch,
85-
}
74+
return remove_not_given(
75+
{
76+
"model": model,
77+
"customModel": custom_model,
78+
"prompt": prompt,
79+
"maxTokens": max_tokens,
80+
"numResults": num_results,
81+
"minTokens": min_tokens,
82+
"temperature": temperature,
83+
"topP": top_p,
84+
"topKReturn": top_k_return,
85+
"stopSequences": stop_sequences,
86+
"frequencyPenalty": NOT_GIVEN if frequency_penalty is NOT_GIVEN else frequency_penalty.to_dict(),
87+
"presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(),
88+
"countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(),
89+
"epoch": epoch,
90+
}
91+
)

ai21/clients/sagemaker/resources/sagemaker_completion.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,43 @@
1-
from typing import Optional, List
1+
from typing import List
22

33
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
44
from ai21.models import Penalty, CompletionsResponse
5+
from ai21.types import NotGiven, NOT_GIVEN
6+
from ai21.utils.typing import remove_not_given
57

68

79
class SageMakerCompletion(SageMakerResource):
810
def create(
911
self,
1012
prompt: str,
1113
*,
12-
max_tokens: Optional[int] = None,
13-
num_results: Optional[int] = 1,
14-
min_tokens: Optional[int] = 0,
15-
temperature: Optional[float] = 0.7,
16-
top_p: Optional[int] = 1,
17-
top_k_return: Optional[int] = 0,
18-
stop_sequences: Optional[List[str]] = None,
19-
frequency_penalty: Optional[Penalty] = None,
20-
presence_penalty: Optional[Penalty] = None,
21-
count_penalty: Optional[Penalty] = None,
14+
max_tokens: int | NotGiven = NOT_GIVEN,
15+
num_results: int | NotGiven = NOT_GIVEN,
16+
min_tokens: int | NotGiven = NOT_GIVEN,
17+
temperature: float | NotGiven = NOT_GIVEN,
18+
top_p: float | NotGiven = NOT_GIVEN,
19+
top_k_return: int | NotGiven = NOT_GIVEN,
20+
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
21+
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
22+
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
23+
count_penalty: Penalty | NotGiven = NOT_GIVEN,
2224
**kwargs,
2325
) -> CompletionsResponse:
24-
body = {
25-
"prompt": prompt,
26-
"maxTokens": max_tokens,
27-
"numResults": num_results,
28-
"minTokens": min_tokens,
29-
"temperature": temperature,
30-
"topP": top_p,
31-
"topKReturn": top_k_return,
32-
"stopSequences": stop_sequences or [],
33-
}
34-
35-
if frequency_penalty is not None:
36-
body["frequencyPenalty"] = frequency_penalty.to_dict()
37-
38-
if presence_penalty is not None:
39-
body["presencePenalty"] = presence_penalty.to_dict()
40-
41-
if count_penalty is not None:
42-
body["countPenalty"] = count_penalty.to_dict()
26+
body = remove_not_given(
27+
{
28+
"prompt": prompt,
29+
"maxTokens": max_tokens,
30+
"numResults": num_results,
31+
"minTokens": min_tokens,
32+
"temperature": temperature,
33+
"topP": top_p,
34+
"topKReturn": top_k_return,
35+
"stopSequences": stop_sequences or [],
36+
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
37+
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
38+
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
39+
}
40+
)
4341

4442
raw_response = self._invoke(body)
4543

ai21/clients/studio/resources/studio_completion.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Optional, List
1+
from __future__ import annotations
2+
3+
from typing import List
24

35
from ai21.clients.common.completion_base import Completion
46
from ai21.clients.studio.resources.studio_resource import StudioResource
57
from ai21.models import Penalty, CompletionsResponse
8+
from ai21.types import NOT_GIVEN, NotGiven
69

710

811
class StudioCompletion(StudioResource, Completion):
@@ -11,23 +14,23 @@ def create(
1114
model: str,
1215
prompt: str,
1316
*,
14-
max_tokens: Optional[int] = None,
15-
num_results: Optional[int] = 1,
16-
min_tokens: Optional[int] = 0,
17-
temperature: Optional[float] = 0.7,
18-
top_p: Optional[float] = 1,
19-
top_k_return: Optional[int] = 0,
20-
custom_model: Optional[str] = None,
21-
stop_sequences: Optional[List[str]] = None,
22-
frequency_penalty: Optional[Penalty] = None,
23-
presence_penalty: Optional[Penalty] = None,
24-
count_penalty: Optional[Penalty] = None,
25-
epoch: Optional[int] = None,
17+
max_tokens: int | NotGiven = NOT_GIVEN,
18+
num_results: int | NotGiven = NOT_GIVEN,
19+
min_tokens: int | NotGiven = NOT_GIVEN,
20+
temperature: float | NotGiven = NOT_GIVEN,
21+
top_p: float | NotGiven = NOT_GIVEN,
22+
top_k_return: int | NotGiven = NOT_GIVEN,
23+
custom_model: str | NotGiven = NOT_GIVEN,
24+
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
25+
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
26+
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
27+
count_penalty: Penalty | NotGiven = NOT_GIVEN,
28+
epoch: int | NotGiven = NOT_GIVEN,
2629
**kwargs,
2730
) -> CompletionsResponse:
2831
url = f"{self._client.get_base_url()}/{model}"
2932

30-
if custom_model is not None:
33+
if custom_model:
3134
url = f"{url}/{custom_model}"
3235

3336
url = f"{url}/{self._module_name}"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from dataclasses_json import LetterCase, DataClassJsonMixin
22

3+
from ai21.utils.typing import is_not_given
4+
35

46
class AI21BaseModelMixin(DataClassJsonMixin):
57
dataclass_json_config = {
68
"letter_case": LetterCase.CAMEL,
9+
"exclude": is_not_given,
710
}

ai21/models/penalty.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
from __future__ import annotations
2+
13
from dataclasses import dataclass
2-
from typing import Optional
34

5+
from ai21.types import NOT_GIVEN, NotGiven
46
from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin
57

68

79
@dataclass
810
class Penalty(AI21BaseModelMixin):
911
scale: float
10-
apply_to_whitespaces: Optional[bool] = None
11-
apply_to_punctuation: Optional[bool] = None
12-
apply_to_numbers: Optional[bool] = None
13-
apply_to_stopwords: Optional[bool] = None
14-
apply_to_emojis: Optional[bool] = None
12+
apply_to_whitespaces: bool | NotGiven = NOT_GIVEN
13+
apply_to_punctuation: bool | NotGiven = NOT_GIVEN
14+
apply_to_numbers: bool | NotGiven = NOT_GIVEN
15+
apply_to_stopwords: bool | NotGiven = NOT_GIVEN
16+
apply_to_emojis: bool | NotGiven = NOT_GIVEN

ai21/types.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing_extensions import Literal
2+
3+
4+
# Sentinel class used until PEP 0661 is accepted
5+
class NotGiven:
6+
"""
7+
A sentinel singleton class used to distinguish omitted keyword arguments
8+
from those passed in with the value None (which may have different behavior).
9+
10+
For example:
11+
12+
```py
13+
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
14+
...
15+
16+
17+
get(timeout=1) # 1s timeout
18+
get(timeout=None) # No timeout
19+
get() # Default timeout behavior, which may not be statically known at the method definition.
20+
```
21+
"""
22+
23+
def __bool__(self) -> Literal[False]:
24+
return False
25+
26+
def __repr__(self) -> str:
27+
return "NOT_GIVEN"
28+
29+
30+
NOT_GIVEN = NotGiven()

ai21/utils/__init__.py

Whitespace-only changes.

ai21/utils/typing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Any, Dict
2+
3+
from ai21.types import NotGiven
4+
5+
6+
def is_not_given(value: Any) -> bool:
7+
return isinstance(value, NotGiven)
8+
9+
10+
def remove_not_given(body: Dict[str, Any]) -> Dict[str, Any]:
11+
return {k: v for k, v in body.items() if not is_not_given(v)}
12+
13+
14+
def to_camel_case(snake_str: str) -> str:
15+
return "".join(x.capitalize() for x in snake_str.lower().split("_"))
16+
17+
18+
def to_lower_camel_case(snake_str: str) -> str:
19+
# We capitalize the first letter of each component except the first one
20+
# with the 'capitalize' method and join them together.
21+
camel_string = to_camel_case(snake_str)
22+
return snake_str[0].lower() + camel_string[1:]

0 commit comments

Comments
 (0)