Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion examples/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,21 @@ def test_create_thread_run(openai_mock: OpenAIMock):
model="gpt-4-turbo",
)

run = client.beta.threads.create_and_run(assistant_id=assistant.id)
run = client.beta.threads.create_and_run(
assistant_id=assistant.id,
thread={"messages": [{"role": "user", "content": "Hi"}]},
model=None,
instructions=None,
tools=None,
)

messages = client.beta.threads.messages.list(run.thread_id)

assert run.id
assert run.assistant_id == assistant.id
assert run.instructions == assistant.instructions
assert run.tools == assistant.tools
assert len(messages.data) == 1

assert openai_mock.beta.assistants.create.calls.call_count == 1
assert openai_mock.beta.threads.create_and_run.calls.call_count == 1
Expand Down
7 changes: 3 additions & 4 deletions src/openai_responses/_routes/assistants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from typing import Any
from typing_extensions import override

Expand All @@ -20,7 +19,7 @@
)

from .._utils.faker import faker
from .._utils.serde import model_dict, model_parse
from .._utils.serde import json_loads, model_dict, model_parse
from .._utils.time import utcnow_unix_timestamp_s

__all__ = [
Expand Down Expand Up @@ -52,7 +51,7 @@ def _handler(self, request: httpx.Request, route: respx.Route) -> httpx.Response

@staticmethod
def _build(partial: PartialAssistant, request: httpx.Request) -> Assistant:
content = json.loads(request.content)
content = json_loads(request.content)
defaults: PartialAssistant = {
"id": faker.beta.assistant.id(),
"created_at": utcnow_unix_timestamp_s(),
Expand Down Expand Up @@ -152,7 +151,7 @@ def _handler(
if not found:
return httpx.Response(404)

content: AssistantUpdateParams = json.loads(request.content)
content: AssistantUpdateParams = json_loads(request.content)
deserialized = model_dict(found)
updated = model_parse(Assistant, deserialized | content)
self._state.beta.assistants.put(updated)
Expand Down
6 changes: 2 additions & 4 deletions src/openai_responses/_routes/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import httpx
import respx

Expand All @@ -10,7 +8,7 @@
from .._types.partials.chat import PartialChatCompletion

from .._utils.faker import faker
from .._utils.serde import model_parse
from .._utils.serde import json_loads, model_parse
from .._utils.time import utcnow_unix_timestamp_s

__all__ = ["ChatCompletionsCreateRoute"]
Expand All @@ -28,7 +26,7 @@ def _build(
partial: PartialChatCompletion,
request: httpx.Request,
) -> ChatCompletion:
content = json.loads(request.content)
content = json_loads(request.content)
defaults: PartialChatCompletion = {
"id": partial.get("id", faker.chat.completion.id()),
"created": partial.get("created", utcnow_unix_timestamp_s()),
Expand Down
6 changes: 2 additions & 4 deletions src/openai_responses/_routes/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

from openai.types.embedding import Embedding
from openai.types.embedding_create_params import EmbeddingCreateParams
from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage
Expand All @@ -11,7 +9,7 @@

from .._types.partials.embeddings import PartialCreateEmbeddingResponse

from .._utils.serde import model_parse
from .._utils.serde import json_loads, model_parse

__all__ = ["EmbeddingsCreateRoute"]

Expand All @@ -33,7 +31,7 @@ def _build(
partial: PartialCreateEmbeddingResponse,
request: httpx.Request,
) -> CreateEmbeddingResponse:
content: EmbeddingCreateParams = json.loads(request.content)
content: EmbeddingCreateParams = json_loads(request.content)
embeddings = partial.get("data", [])
response = CreateEmbeddingResponse(
data=[model_parse(Embedding, e) for e in embeddings],
Expand Down
7 changes: 3 additions & 4 deletions src/openai_responses/_routes/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from typing import Any
from typing_extensions import override

Expand All @@ -20,7 +19,7 @@
)

from .._utils.faker import faker
from .._utils.serde import model_dict, model_parse
from .._utils.serde import json_loads, model_dict, model_parse
from .._utils.time import utcnow_unix_timestamp_s

__all__ = [
Expand Down Expand Up @@ -65,7 +64,7 @@ def _handler(

@staticmethod
def _build(partial: PartialMessage, request: httpx.Request) -> Message:
content = json.loads(request.content)
content = json_loads(request.content)
defaults: PartialMessage = {
"id": faker.beta.thread.message.id(),
"content": [],
Expand Down Expand Up @@ -209,7 +208,7 @@ def _handler(
if not found_message:
return httpx.Response(404)

content: MessageUpdateParams = json.loads(request.content)
content: MessageUpdateParams = json_loads(request.content)
deserialized = model_dict(found_message)
updated = model_parse(Message, deserialized | content)
self._state.beta.threads.messages.put(updated)
Expand Down
19 changes: 13 additions & 6 deletions src/openai_responses/_routes/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@

from ._base import StatefulRoute

from ..helpers.builders.messages import message_from_create_request
from ..helpers.builders.threads import thread_from_create_request

from .._stores import StateStore
from .._types.partials.runs import PartialRun, PartialRunList

from .._utils.copy import model_copy
from .._utils.faker import faker
from .._utils.serde import model_dict, model_parse
from .._utils.serde import json_loads, model_dict, model_parse
from .._utils.time import utcnow_unix_timestamp_s


Expand Down Expand Up @@ -59,7 +60,7 @@ def _handler(
if not found_thread:
return httpx.Response(404)

content: RunCreateParams = json.loads(request.content)
content: RunCreateParams = json_loads(request.content)

found_asst = self._state.beta.assistants.get(content["assistant_id"])
if not found_asst:
Expand All @@ -82,7 +83,7 @@ def _handler(

@staticmethod
def _build(partial: PartialRun, request: httpx.Request) -> Run:
content = json.loads(request.content)
content = json_loads(request.content)
defaults: PartialRun = {
"id": faker.beta.thread.run.id(),
"created_at": utcnow_unix_timestamp_s(),
Expand All @@ -105,7 +106,7 @@ def __init__(self, router: respx.MockRouter, state: StateStore) -> None:
def _handler(self, request: httpx.Request, route: respx.Route) -> httpx.Response:
self._route = route

content: ThreadCreateAndRunParams = json.loads(request.content)
content: ThreadCreateAndRunParams = json_loads(request.content)

found_asst = self._state.beta.assistants.get(content["assistant_id"])
if not found_asst:
Expand All @@ -117,6 +118,12 @@ def _handler(self, request: httpx.Request, route: respx.Route) -> httpx.Response
thread = thread_from_create_request(thread_create_req)
self._state.beta.threads.put(thread)

for message_create_params in thread_create_params.get("messages", []):
encoded = json.dumps(message_create_params).encode("utf-8")
create_message_req = httpx.Request(method="", url="", content=encoded)
message = message_from_create_request(thread.id, create_message_req)
self._state.beta.threads.messages.put(message)

model = self._build(
{
"thread_id": thread.id,
Expand All @@ -134,7 +141,7 @@ def _handler(self, request: httpx.Request, route: respx.Route) -> httpx.Response

@staticmethod
def _build(partial: PartialRun, request: httpx.Request) -> Run:
content = json.loads(request.content)
content = json_loads(request.content)
if content.get("thread"):
del content["thread"]

Expand Down Expand Up @@ -260,7 +267,7 @@ def _handler(
if not found_run:
return httpx.Response(404)

content: RunUpdateParams = json.loads(request.content)
content: RunUpdateParams = json_loads(request.content)
deserialized = model_dict(found_run)
updated = model_parse(Run, deserialized | content)
self._state.beta.threads.runs.put(updated)
Expand Down
15 changes: 7 additions & 8 deletions src/openai_responses/_routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .._types.partials.threads import PartialThread, PartialThreadDeleted

from .._utils.faker import faker
from .._utils.serde import model_dict, model_parse
from .._utils.serde import json_loads, model_dict, model_parse
from .._utils.time import utcnow_unix_timestamp_s


Expand All @@ -42,16 +42,15 @@ def __init__(self, router: respx.MockRouter, state: StateStore) -> None:
def _handler(self, request: httpx.Request, route: respx.Route) -> httpx.Response:
self._route = route

content: ThreadCreateParams = json.loads(request.content)
content: ThreadCreateParams = json_loads(request.content)
model = self._build({}, request)
self._state.beta.threads.put(model)

if content.get("messages"):
for message_create_params in content.get("messages", []):
encoded = json.dumps(message_create_params).encode("utf-8")
create_message_req = httpx.Request(method="", url="", content=encoded)
message = message_from_create_request(model.id, create_message_req)
self._state.beta.threads.messages.put(message)
for message_create_params in content.get("messages", []):
encoded = json.dumps(message_create_params).encode("utf-8")
create_message_req = httpx.Request(method="", url="", content=encoded)
message = message_from_create_request(model.id, create_message_req)
self._state.beta.threads.messages.put(message)

return httpx.Response(
status_code=self._status_code,
Expand Down
8 changes: 8 additions & 0 deletions src/openai_responses/_utils/serde.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import json
from typing import Any, Type

from openai import BaseModel

from .._types.generics import M

__all__ = ["json_loads", "model_dict", "model_parse"]


def json_loads(b: bytes) -> Any:
    """Deserialize a JSON request body, dropping top-level ``null`` values.

    The OpenAI client serializes omitted optional parameters as explicit
    JSON ``null``s. Stripping those keys lets downstream merges such as
    ``model_dict(found) | content`` fall back to the stored values instead
    of overwriting them with ``None``.

    Args:
        b: Raw JSON bytes (e.g. ``httpx.Request.content``).

    Returns:
        The parsed value. For a JSON object, a dict with ``None``-valued
        top-level keys removed; any other JSON value (array, scalar) is
        returned unchanged.
    """
    parsed = json.loads(b)
    if not isinstance(parsed, dict):
        # Defensive: original code crashed with AttributeError on non-object
        # payloads; pass them through untouched instead.
        return parsed
    return {key: value for key, value in parsed.items() if value is not None}


def model_dict(m: BaseModel) -> dict[str, Any]:
if hasattr(m, "model_dump"):
Expand Down