Skip to content

Commit b0b5bc1

Browse files
authored
fix: add logit bias to studio, sagemaker (#70)
* fix: add logit bias, fix studio completion * fix: adjust tests * fix: add logit bias integration test * fix: update studio completion example * fix: fix studio completion example * fix: fix studio completion example * fix: remove logit bias from bedrock * fix: add logit bias to sagemaker completion, add params string * fix: adjust tests * fix: add logit bias integration test * fix: update studio completion example * fix: fix studio completion example * fix: update code with new not_giving approach
1 parent 5e9a768 commit b0b5bc1

File tree

5 files changed

+54
-2
lines changed

5 files changed

+54
-2
lines changed

ai21/clients/common/completion_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def create(
2929
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
3030
count_penalty: Penalty | NotGiven = NOT_GIVEN,
3131
epoch: int | NotGiven = NOT_GIVEN,
32+
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
3233
**kwargs,
3334
) -> CompletionsResponse:
3435
"""
@@ -46,6 +47,9 @@ def create(
4647
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
4748
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
4849
:param epoch:
50+
:param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text
51+
representations of the tokens and the floats are the biases themselves. A positive bias increases generation
52+
probability for a given token and a negative bias decreases it.
4953
:param kwargs:
5054
:return:
5155
"""
@@ -70,6 +74,7 @@ def _create_body(
7074
presence_penalty: Penalty | NotGiven,
7175
count_penalty: Penalty | NotGiven,
7276
epoch: int | NotGiven,
77+
logit_bias: Dict[str, float] | NotGiven,
7378
):
7479
return remove_not_given(
7580
{
@@ -87,5 +92,6 @@ def _create_body(
8792
"presencePenalty": NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty.to_dict(),
8893
"countPenalty": NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty.to_dict(),
8994
"epoch": epoch,
95+
"logitBias": logit_bias,
9096
}
9197
)

ai21/clients/sagemaker/resources/sagemaker_completion.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Dict
22

33
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
44
from ai21.models import Penalty, CompletionsResponse
@@ -21,8 +21,27 @@ def create(
2121
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
2222
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
2323
count_penalty: Penalty | NotGiven = NOT_GIVEN,
24+
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
2425
**kwargs,
2526
) -> CompletionsResponse:
27+
"""
28+
:param prompt: Text for model to complete
29+
:param max_tokens: The maximum number of tokens to generate per result
30+
:param num_results: Number of completions to sample and return.
31+
:param min_tokens: The minimum number of tokens to generate per result.
32+
:param temperature: A value controlling the "creativity" of the model's responses.
33+
:param top_p: A value controlling the diversity of the model's responses.
34+
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
35+
:param stop_sequences: Stops decoding if any of the strings is generated
36+
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
37+
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
38+
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
39+
:param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text
40+
representations of the tokens and the floats are the biases themselves. A positive bias increases generation
41+
probability for a given token and a negative bias decreases it.
42+
:param kwargs:
43+
:return:
44+
"""
2645
body = remove_not_given(
2746
{
2847
"prompt": prompt,
@@ -36,6 +55,7 @@ def create(
3655
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
3756
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
3857
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
58+
"logitBias": logit_bias,
3959
}
4060
)
4161

ai21/clients/studio/resources/studio_completion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import List
3+
from typing import List, Dict
44

55
from ai21.clients.common.completion_base import Completion
66
from ai21.clients.studio.resources.studio_resource import StudioResource
@@ -26,6 +26,7 @@ def create(
2626
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
2727
count_penalty: Penalty | NotGiven = NOT_GIVEN,
2828
epoch: int | NotGiven = NOT_GIVEN,
29+
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
2930
**kwargs,
3031
) -> CompletionsResponse:
3132
url = f"{self._client.get_base_url()}/{model}"
@@ -49,5 +50,6 @@ def create(
4950
presence_penalty=presence_penalty,
5051
count_penalty=count_penalty,
5152
epoch=epoch,
53+
logit_bias=logit_bias,
5254
)
5355
return self._json_to_response(self._post(url=url, body=body))

examples/studio/completion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
num_results=1,
4545
custom_model=None,
4646
epoch=1,
47+
logit_bias={"▁I'm▁sorry": -100.0},
4748
count_penalty=Penalty(
4849
scale=0,
4950
apply_to_emojis=False,

tests/integration_tests/clients/studio/test_completion.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from typing import Dict
34
from ai21 import AI21Client
45
from ai21.models import Penalty
56

@@ -25,6 +26,7 @@ def test_completion():
2526
num_results=num_results,
2627
custom_model=None,
2728
epoch=1,
29+
logit_bias={"▁a▁box▁of": -100.0},
2830
count_penalty=Penalty(
2931
scale=0,
3032
apply_to_emojis=False,
@@ -110,3 +112,24 @@ def test_completion_when_finish_reason_defined__should_halt_on_expected_reason(
110112
)
111113

112114
assert response.completions[0].finish_reason.reason == reason
115+
116+
117+
@pytest.mark.parametrize(
118+
ids=[
119+
"no_logit_bias",
120+
"logit_bias_negative",
121+
],
122+
argnames=["expected_result", "logit_bias"],
123+
argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})],
124+
)
125+
def test_completion_logit_bias__should_impact_on_response(expected_result: str, logit_bias: Dict[str, float]):
126+
client = AI21Client()
127+
response = client.completion.create(
128+
prompt="Life is like",
129+
max_tokens=3,
130+
model="j2-ultra",
131+
temperature=0,
132+
logit_bias=logit_bias,
133+
)
134+
135+
assert response.completions[0].data.text.strip() == expected_result.strip()

0 commit comments

Comments (0)