Skip to content

Commit e849d78

Browse files
authored
fix: Pass default model_id in bedrock client (#72)
* fix: Pass default model_id in bedrock client * fix: Added validation for model_id * fix: Added validation for model_id
1 parent 409e818 commit e849d78

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

ai21/clients/bedrock/ai21_bedrock_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class AI21BedrockClient:
1515

1616
def __init__(
1717
self,
18-
model_id: str,
18+
model_id: Optional[str] = None,
1919
session: Optional[boto3.Session] = None,
2020
region: Optional[str] = None,
2121
env_config: _AI21EnvConfig = AI21EnvConfig,

ai21/clients/bedrock/resources/bedrock_completion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def create(
4343

4444
model_id = kwargs.get("model_id", self._model_id)
4545

46+
if model_id is None:
47+
raise ValueError("model_id should be provided in either the constructor or the 'create' method call")
48+
4649
raw_response = self._invoke(model_id=model_id, body=body)
4750

4851
return CompletionsResponse.from_dict(raw_response)

tests/integration_tests/clients/bedrock/test_completion.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,12 @@ def test_completion__when_no_penalties__should_return_response(
6868
assert len([completion.data.text for completion in response.completions]) == 1
6969
for completion in response.completions:
7070
assert isinstance(completion.data.text, str)
71+
72+
73+
@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.")
74+
def test_completion__when_no_model_id__should_raise_exception():
75+
with pytest.raises(ValueError) as e:
76+
client = AI21BedrockClient()
77+
client.completion.create(prompt=_PROMPT)
78+
79+
assert e.value.args[0] == "model_id should be provided in either the constructor or the 'create' method call"

0 commit comments

Comments
 (0)