Skip to content

Commit 23772dc

Browse files
benshukbenshuk
andauthored
feat: ✨ add support for str message content (#245)
* feat: ✨ add support for `str` message content * fix: 🏷️ typing * fix: 🏷️ typing again * fix: 🏷️ typing thingy --------- Co-authored-by: benshuk <bens@ai21.com>
1 parent 21f6df3 commit 23772dc

File tree

7 files changed

+33
-29
lines changed

7 files changed

+33
-29
lines changed

ai21/clients/common/beta/assistant/messages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC, abstractmethod
44

5-
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
5+
from ai21.models.assistant.message import ThreadMessageRole, ThreadMessageContent
66
from ai21.models.responses.message_response import MessageResponse, ListMessageResponse
77

88

@@ -15,7 +15,7 @@ def create(
1515
thread_id: str,
1616
*,
1717
role: ThreadMessageRole,
18-
content: MessageContentText,
18+
content: ThreadMessageContent,
1919
**kwargs,
2020
) -> MessageResponse:
2121
pass

ai21/clients/studio/resources/beta/assistant/thread.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
99
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
1010
from ai21.http_client.http_client import AI21HTTPClient
11-
from ai21.models.assistant.message import Message
11+
from ai21.models.assistant.message import Message, modify_message_content
1212
from ai21.models.responses.thread_response import ThreadResponse
1313

1414

@@ -24,7 +24,7 @@ def create(
2424
messages: List[Message],
2525
**kwargs,
2626
) -> ThreadResponse:
27-
body = dict(messages=messages)
27+
body = dict(messages=[modify_message_content(message) for message in messages])
2828

2929
return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)
3030

@@ -44,7 +44,7 @@ async def create(
4444
messages: List[Message],
4545
**kwargs,
4646
) -> ThreadResponse:
47-
body = dict(messages=messages)
47+
body = dict(messages=[modify_message_content(message) for message in messages])
4848

4949
return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)
5050

ai21/clients/studio/resources/beta/assistant/thread_messages.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
34
from ai21.clients.common.beta.assistant.messages import BaseMessages
45
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
5-
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
6+
from ai21.models.assistant.message import ThreadMessageRole, modify_message_content, Message, ThreadMessageContent
67
from ai21.models.responses.message_response import MessageResponse, ListMessageResponse
78

89

@@ -12,13 +13,10 @@ def create(
1213
thread_id: str,
1314
*,
1415
role: ThreadMessageRole,
15-
content: MessageContentText,
16+
content: ThreadMessageContent,
1617
**kwargs,
1718
) -> MessageResponse:
18-
body = dict(
19-
role=role,
20-
content=content,
21-
)
19+
body = modify_message_content(Message(role=role, content=content))
2220

2321
return self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse)
2422

@@ -32,13 +30,10 @@ async def create(
3230
thread_id: str,
3331
*,
3432
role: ThreadMessageRole,
35-
content: MessageContentText,
33+
content: ThreadMessageContent,
3634
**kwargs,
3735
) -> MessageResponse:
38-
body = dict(
39-
role=role,
40-
content=content,
41-
)
36+
body = modify_message_content(Message(role=role, content=content))
4237

4338
return await self._post(
4439
path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse

ai21/models/assistant/message.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
1-
from typing import Literal
1+
from __future__ import annotations
2+
from typing import Literal, Union
23

34
from typing_extensions import TypedDict
45

6+
57
ThreadMessageRole = Literal["assistant", "user"]
68

79

8-
class MessageContentText(TypedDict):
10+
class ThreadMessageContentText(TypedDict):
911
type: Literal["text"]
1012
text: str
1113

1214

15+
ThreadMessageContent = Union[str, ThreadMessageContentText]
16+
17+
1318
class Message(TypedDict):
1419
role: ThreadMessageRole
15-
content: MessageContentText
20+
content: ThreadMessageContent
21+
22+
23+
def modify_message_content(message: Message) -> Message:
24+
role = message["role"]
25+
content = message["content"]
26+
27+
if isinstance(content, str):
28+
content = ThreadMessageContentText(type="text", text=content)
29+
30+
return Message(role=role, content=content)

ai21/models/responses/message_response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Literal, Optional, List
33

44
from ai21.models.ai21_base_model import AI21BaseModel
5-
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
5+
from ai21.models.assistant.message import ThreadMessageRole, ThreadMessageContentText
66

77

88
class MessageResponse(AI21BaseModel):
@@ -11,7 +11,7 @@ class MessageResponse(AI21BaseModel):
1111
updated_at: datetime
1212
object: Literal["message"] = "message"
1313
role: ThreadMessageRole
14-
content: MessageContentText
14+
content: ThreadMessageContentText
1515
run_id: Optional[str] = None
1616
assistant_id: Optional[str] = None
1717

examples/studio/assistant/assistant.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ def main():
1414
messages=[
1515
{
1616
"role": "user",
17-
"content": {
18-
"type": "text",
19-
"text": "Hello",
20-
},
17+
"content": "Hello",
2118
},
2219
]
2320
)

examples/studio/assistant/async_assistant.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ async def main():
1616
messages=[
1717
{
1818
"role": "user",
19-
"content": {
20-
"type": "text",
21-
"text": "Hello",
22-
},
19+
"content": "Hello",
2320
},
2421
]
2522
)

0 commit comments

Comments
 (0)