Skip to content

Commit d7c912f

Browse files
fix: penalties in Sagemaker and Bedrock (#67)
* fix: penalties in sagemaker * fix: don't pass None penalties to Bedrock * fix: remove some default args, and some unused args from bedrock model * test: Added bedrock integration tests for penalties check * ci: Integration tests on push * fix: answer test --------- Co-authored-by: etang <etang@ai21.com>
1 parent 8b0f217 commit d7c912f

File tree

7 files changed

+92
-43
lines changed

7 files changed

+92
-43
lines changed

.github/workflows/integration-tests.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
name: Integration Tests
22

3-
on:
4-
push:
5-
branches:
6-
- main
7-
- "rc_*"
3+
on: [push]
84

95
env:
106
POETRY_VERSION: "1.7.1"

ai21/clients/bedrock/resources/bedrock_completion.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,17 @@ def create(
3131
"topP": top_p,
3232
"topKReturn": top_k_return,
3333
"stopSequences": stop_sequences or [],
34-
"frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(),
35-
"presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(),
36-
"countPenalty": None if count_penalty is None else count_penalty.to_dict(),
3734
}
35+
36+
if frequency_penalty is not None:
37+
body["frequencyPenalty"] = frequency_penalty.to_dict()
38+
39+
if presence_penalty is not None:
40+
body["presencePenalty"] = presence_penalty.to_dict()
41+
42+
if count_penalty is not None:
43+
body["countPenalty"] = count_penalty.to_dict()
44+
3845
raw_response = self._invoke(model_id=model_id, body=body)
3946

4047
return CompletionsResponse.from_dict(raw_response)

ai21/clients/sagemaker/resources/sagemaker_completion.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,17 @@ def create(
3030
"topP": top_p,
3131
"topKReturn": top_k_return,
3232
"stopSequences": stop_sequences or [],
33-
"frequencyPenalty": None if frequency_penalty is None else frequency_penalty.to_dict(),
34-
"presencePenalty": None if presence_penalty is None else presence_penalty.to_dict(),
35-
"countPenalty": None if count_penalty is None else count_penalty.to_dict(),
3633
}
34+
35+
if frequency_penalty is not None:
36+
body["frequencyPenalty"] = frequency_penalty.to_dict()
37+
38+
if presence_penalty is not None:
39+
body["presencePenalty"] = presence_penalty.to_dict()
40+
41+
if count_penalty is not None:
42+
body["countPenalty"] = count_penalty.to_dict()
43+
3744
raw_response = self._invoke(body)
3845

3946
return CompletionsResponse.from_dict(raw_response)

examples/bedrock/completion.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from ai21 import AI21BedrockClient, BedrockModelID
2-
from ai21.models import Penalty
32

43
# Bedrock is currently supported only in us-east-1 region.
54
# Either set your profile's region to us-east-1 or uncomment next line
@@ -46,34 +45,6 @@
4645
temperature=0,
4746
top_p=1,
4847
top_k_return=0,
49-
stop_sequences=["##"],
50-
num_results=1,
51-
custom_model=None,
52-
epoch=1,
53-
count_penalty=Penalty(
54-
scale=0,
55-
apply_to_emojis=False,
56-
apply_to_numbers=False,
57-
apply_to_stopwords=False,
58-
apply_to_punctuation=False,
59-
apply_to_whitespaces=False,
60-
),
61-
frequency_penalty=Penalty(
62-
scale=0,
63-
apply_to_emojis=False,
64-
apply_to_numbers=False,
65-
apply_to_stopwords=False,
66-
apply_to_punctuation=False,
67-
apply_to_whitespaces=False,
68-
),
69-
presence_penalty=Penalty(
70-
scale=0,
71-
apply_to_emojis=False,
72-
apply_to_numbers=False,
73-
apply_to_stopwords=False,
74-
apply_to_punctuation=False,
75-
apply_to_whitespaces=False,
76-
),
7748
)
7849

7950
print(response.completions[0].data.text)

examples/studio/answer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from ai21 import AI21Client
2-
from ai21.models import Mode, AnswerLength
32

43

54
client = AI21Client()
@@ -10,7 +9,5 @@
109
"ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and "
1110
"economic power, dominating the other provinces of the newly independent Dutch Republic.",
1211
question="When did Holland become an economic power?",
13-
answer_length=AnswerLength.LONG,
14-
mode=Mode.FLEXIBLE,
1512
)
1613
print(response)

tests/integration_tests/clients/bedrock/__init__.py

Whitespace-only changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Optional
2+
3+
import pytest
4+
5+
from ai21 import AI21BedrockClient
6+
from ai21.clients.bedrock.bedrock_model_id import BedrockModelID
7+
from ai21.models import Penalty
8+
from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests
9+
10+
_PROMPT = "Once upon a time, in a land far, far away, there was a"
11+
12+
13+
@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.")
14+
@pytest.mark.parametrize(
15+
ids=[
16+
"when_no_penalties__should_return_response",
17+
"when_penalties__should_return_response",
18+
],
19+
argnames=["frequency_penalty", "presence_penalty", "count_penalty"],
20+
argvalues=[
21+
(None, None, None),
22+
(
23+
Penalty(
24+
scale=0.5,
25+
apply_to_emojis=True,
26+
apply_to_numbers=True,
27+
apply_to_stopwords=True,
28+
apply_to_punctuation=True,
29+
apply_to_whitespaces=True,
30+
),
31+
Penalty(
32+
scale=0.5,
33+
apply_to_emojis=True,
34+
apply_to_numbers=True,
35+
apply_to_stopwords=True,
36+
apply_to_punctuation=True,
37+
apply_to_whitespaces=True,
38+
),
39+
Penalty(
40+
scale=0.5,
41+
apply_to_emojis=True,
42+
apply_to_numbers=True,
43+
apply_to_stopwords=True,
44+
apply_to_punctuation=True,
45+
apply_to_whitespaces=True,
46+
),
47+
),
48+
],
49+
)
50+
def test_completion__when_no_penalties__should_return_response(
51+
frequency_penalty: Optional[Penalty], presence_penalty: Optional[Penalty], count_penalty: Optional[Penalty]
52+
):
53+
client = AI21BedrockClient()
54+
response = client.completion.create(
55+
prompt=_PROMPT,
56+
max_tokens=64,
57+
model_id=BedrockModelID.J2_MID_V1,
58+
temperature=0,
59+
top_p=1,
60+
top_k_return=0,
61+
frequency_penalty=frequency_penalty,
62+
presence_penalty=presence_penalty,
63+
count_penalty=count_penalty,
64+
)
65+
66+
assert response.prompt.text == _PROMPT
67+
assert len(response.completions) == 1
68+
# Check the results aren't all the same
69+
assert len([completion.data.text for completion in response.completions]) == 1
70+
for completion in response.completions:
71+
assert isinstance(completion.data.text, str)

0 commit comments

Comments
 (0)