Skip to content

Commit 29fb8ce

Browse files
authored
Merge pull request #232 from AI21Labs/threads
feat: ✨ add support for Thread resource
2 parents 2cf1a85 + c944172 commit 29fb8ce

File tree

11 files changed

+142
-55
lines changed

11 files changed

+142
-55
lines changed

ai21/clients/common/assistant/assistant.py renamed to ai21/clients/common/assistant/assistants.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,13 @@
33
from abc import ABC, abstractmethod
44
from typing import Any, Dict, List
55

6-
from ai21.models.responses.assistant_response import (
7-
AssistantResponse,
8-
Optimization,
9-
ToolResources,
10-
Tool,
11-
ListAssistantResponse,
12-
)
6+
from ai21.models.assistant.assistant import Optimization, Tool, ToolResources
7+
from ai21.models.responses.assistant_response import Assistant, ListAssistant
138
from ai21.types import NotGiven, NOT_GIVEN
149
from ai21.utils.typing import remove_not_given
1510

1611

17-
class Assistant(ABC):
12+
class Assistants(ABC):
1813
_module_name = "assistants"
1914

2015
@abstractmethod
@@ -29,7 +24,7 @@ def create(
2924
tools: List[Tool] | NotGiven = NOT_GIVEN,
3025
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
3126
**kwargs,
32-
) -> AssistantResponse:
27+
) -> Assistant:
3328
pass
3429

3530
def _create_body(
@@ -58,11 +53,11 @@ def _create_body(
5853
)
5954

6055
@abstractmethod
61-
def list(self) -> ListAssistantResponse:
56+
def list(self) -> ListAssistant:
6257
pass
6358

6459
@abstractmethod
65-
def get(self, assistant_id: str) -> AssistantResponse:
60+
def get(self, assistant_id: str) -> Assistant:
6661
pass
6762

6863
@abstractmethod
@@ -78,5 +73,5 @@ def modify(
7873
models: List[str] | NotGiven = NOT_GIVEN,
7974
tools: List[Tool] | NotGiven = NOT_GIVEN,
8075
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
81-
) -> AssistantResponse:
76+
) -> Assistant:
8277
pass
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import List
5+
6+
from ai21.models.assistant.message import Message
7+
from ai21.models.responses.thread_response import Thread
8+
9+
10+
class Threads(ABC):
11+
_module_name = "threads"
12+
13+
@abstractmethod
14+
def create(self, messages: List[Message], **kwargs) -> Thread:
15+
pass
16+
17+
@abstractmethod
18+
def get(self, thread_id: str) -> Thread:
19+
pass

ai21/clients/studio/resources/assistant/studio_assistant.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,17 @@
22

33
from typing import List
44

5-
from ai21.clients.common.assistant.assistant import Assistant
5+
from ai21.clients.common.assistant.assistants import Assistants
66
from ai21.clients.studio.resources.studio_resource import (
77
AsyncStudioResource,
88
StudioResource,
99
)
10-
from ai21.models.responses.assistant_response import (
11-
AssistantResponse,
12-
Tool,
13-
ToolResources,
14-
ListAssistantResponse,
15-
)
10+
from ai21.models.assistant.assistant import Tool, ToolResources
11+
from ai21.models.responses.assistant_response import Assistant, ListAssistant
1612
from ai21.types import NotGiven, NOT_GIVEN
1713

1814

19-
class StudioAssistant(StudioResource, Assistant):
15+
class StudioAssistant(StudioResource, Assistants):
2016
def create(
2117
self,
2218
name: str,
@@ -28,7 +24,7 @@ def create(
2824
tools: List[Tool] | NotGiven = NOT_GIVEN,
2925
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
3026
**kwargs,
31-
) -> AssistantResponse:
27+
) -> Assistant:
3228
body = self._create_body(
3329
name=name,
3430
description=description,
@@ -40,13 +36,13 @@ def create(
4036
**kwargs,
4137
)
4238

43-
return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)
39+
return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)
4440

45-
def get(self, assistant_id: str) -> AssistantResponse:
46-
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)
41+
def get(self, assistant_id: str) -> Assistant:
42+
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)
4743

48-
def list(self) -> ListAssistantResponse:
49-
return self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse)
44+
def list(self) -> ListAssistant:
45+
return self._get(path=f"/{self._module_name}", response_cls=ListAssistant)
5046

5147
def modify(
5248
self,
@@ -60,7 +56,7 @@ def modify(
6056
models: List[str] | NotGiven = NOT_GIVEN,
6157
tools: List[Tool] | NotGiven = NOT_GIVEN,
6258
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
63-
) -> AssistantResponse:
59+
) -> Assistant:
6460
body = self._create_body(
6561
name=name,
6662
description=description,
@@ -72,10 +68,10 @@ def modify(
7268
tool_resources=tool_resources,
7369
)
7470

75-
return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse)
71+
return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant)
7672

7773

78-
class AsyncStudioAssistant(AsyncStudioResource, Assistant):
74+
class AsyncStudioAssistant(AsyncStudioResource, Assistants):
7975
async def create(
8076
self,
8177
name: str,
@@ -87,7 +83,7 @@ async def create(
8783
tools: List[Tool] | NotGiven = NOT_GIVEN,
8884
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
8985
**kwargs,
90-
) -> AssistantResponse:
86+
) -> Assistant:
9187
body = self._create_body(
9288
name=name,
9389
description=description,
@@ -99,13 +95,13 @@ async def create(
9995
**kwargs,
10096
)
10197

102-
return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)
98+
return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)
10399

104-
async def get(self, assistant_id: str) -> AssistantResponse:
105-
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)
100+
async def get(self, assistant_id: str) -> Assistant:
101+
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)
106102

107-
async def list(self) -> ListAssistantResponse:
108-
return await self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse)
103+
async def list(self) -> ListAssistant:
104+
return await self._get(path=f"/{self._module_name}", response_cls=ListAssistant)
109105

110106
async def modify(
111107
self,
@@ -119,7 +115,7 @@ async def modify(
119115
models: List[str] | NotGiven = NOT_GIVEN,
120116
tools: List[Tool] | NotGiven = NOT_GIVEN,
121117
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
122-
) -> AssistantResponse:
118+
) -> Assistant:
123119
body = self._create_body(
124120
name=name,
125121
description=description,
@@ -131,4 +127,4 @@ async def modify(
131127
tool_resources=tool_resources,
132128
)
133129

134-
return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse)
130+
return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
from typing import List
4+
5+
from ai21.clients.common.assistant.threads import Threads
6+
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
7+
from ai21.models.assistant.message import Message
8+
from ai21.models.responses.thread_response import Thread
9+
10+
11+
class StudioThread(StudioResource, Threads):
12+
def create(self, messages: List[Message], **kwargs) -> Thread:
13+
body = dict(messages=messages)
14+
15+
return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread)
16+
17+
def get(self, thread_id: str) -> Thread:
18+
return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread)
19+
20+
21+
class AsyncStudioThread(AsyncStudioResource, Threads):
22+
async def create(self, messages: List[Message], **kwargs) -> Thread:
23+
body = dict(messages=messages)
24+
25+
return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread)
26+
27+
async def get(self, thread_id: str) -> Thread:
28+
return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread)

ai21/clients/studio/resources/beta/async_beta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ai21.clients.studio.resources.assistant.studio_assistant import AsyncStudioAssistant
2+
from ai21.clients.studio.resources.assistant.studio_thread import AsyncStudioThread
23
from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag
34
from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
45
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
@@ -8,5 +9,6 @@ class AsyncBeta(AsyncStudioResource):
89
def __init__(self, client: AsyncAI21HTTPClient):
910
super().__init__(client)
1011

11-
self.conversational_rag = AsyncStudioConversationalRag(client)
1212
self.assistants = AsyncStudioAssistant(client)
13+
self.conversational_rag = AsyncStudioConversationalRag(client)
14+
self.threads = AsyncStudioThread(client)

ai21/clients/studio/resources/beta/beta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ai21.clients.studio.resources.assistant.studio_assistant import StudioAssistant
2+
from ai21.clients.studio.resources.assistant.studio_thread import StudioThread
23
from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag
34
from ai21.clients.studio.resources.studio_resource import StudioResource
45
from ai21.http_client.http_client import AI21HTTPClient
@@ -8,5 +9,6 @@ class Beta(StudioResource):
89
def __init__(self, client: AI21HTTPClient):
910
super().__init__(client)
1011

11-
self.conversational_rag = StudioConversationalRag(client)
1212
self.assistants = StudioAssistant(client)
13+
self.conversational_rag = StudioConversationalRag(client)
14+
self.threads = StudioThread(client)

ai21/models/assistant/__init__.py

Whitespace-only changes.

ai21/models/assistant/assistant.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Optional, Literal
2+
3+
from typing_extensions import TypedDict
4+
5+
Optimization = Literal["cost", "latency"]
6+
Tool = Literal["rag", "internet_research", "plan_approval"]
7+
8+
9+
class ToolResources(TypedDict, total=False):
10+
rag: Optional[dict]
11+
internet_research: Optional[dict]
12+
plan_approval: Optional[dict]

ai21/models/assistant/message.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from datetime import datetime
2+
from typing import Literal, Optional
3+
4+
from typing_extensions import TypedDict
5+
6+
from ai21.models.ai21_base_model import AI21BaseModel
7+
8+
ThreadMessageRole = Literal["assistant", "user"]
9+
10+
11+
class MessageContentText(TypedDict):
12+
type: Literal["text"]
13+
text: str
14+
15+
16+
class Message(TypedDict):
17+
role: ThreadMessageRole
18+
content: MessageContentText
19+
20+
21+
class MessageResponse(AI21BaseModel):
22+
id: str
23+
created_at: datetime
24+
updated_at: datetime
25+
object: Literal["message"] = "message"
26+
role: ThreadMessageRole
27+
content: MessageContentText
28+
run_id: Optional[str] = None
29+
assistant_id: Optional[str] = None
Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,15 @@
11
from datetime import datetime
22
from typing import Optional, List, Literal
33

4-
from typing_extensions import TypedDict
5-
64
from ai21.models.ai21_base_model import AI21BaseModel
5+
from ai21.models.assistant.assistant import ToolResources
76

87

9-
Optimization = Literal["cost", "latency"]
10-
Tool = Literal["rag", "internet_research", "plan_approval"]
11-
12-
13-
class ToolResources(TypedDict, total=False):
14-
rag: Optional[dict]
15-
internet_research: Optional[dict]
16-
plan_approval: Optional[dict]
17-
18-
19-
class AssistantResponse(AI21BaseModel):
8+
class Assistant(AI21BaseModel):
209
id: str
2110
created_at: datetime
2211
updated_at: datetime
23-
object: str
12+
object: Literal["assistant"] = "assistant"
2413
name: str
2514
description: Optional[str] = None
2615
optimization: str
@@ -33,5 +22,5 @@ class AssistantResponse(AI21BaseModel):
3322
tool_resources: Optional[ToolResources] = None
3423

3524

36-
class ListAssistantResponse(AI21BaseModel):
37-
results: List[AssistantResponse]
25+
class ListAssistant(AI21BaseModel):
26+
results: List[Assistant]

0 commit comments

Comments
 (0)