Skip to content

Commit

Permalink
chore(internal): split up transforms into sync / async (#1210)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Mar 13, 2024
1 parent 1879c97 commit 31bfc12
Show file tree
Hide file tree
Showing 17 changed files with 363 additions and 111 deletions.
2 changes: 2 additions & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,7 @@
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
128 changes: 123 additions & 5 deletions src/openai/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,7 @@ def _transform_recursive(
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)

return _transform_value(data, annotation)


def _transform_value(data: object, type_: type) -> object:
annotated_type = _get_annotated_type(type_)
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data

Expand Down Expand Up @@ -222,3 +218,125 @@ def _transform_typeddict(
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result


async def async_maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `async_transform()` that allows `None` to be passed.
See `async_transform()` for more details.
"""
if data is None:
return None
return await async_transform(data, expected_type)


async def async_transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that does not have type information given will be included as is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)


async def _async_transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
if inner_type is None:
inner_type = annotation

stripped_type = strip_annotated_type(inner_type)
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)

if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)

annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data

# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)

return data


async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()

if format_ == "custom" and format_template is not None:
return data.strftime(format_template)

return data


async def _async_transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result
7 changes: 5 additions & 2 deletions src/openai/resources/audio/speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from ... import _legacy_response
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import maybe_transform
from ..._utils import (
maybe_transform,
async_maybe_transform,
)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
Expand Down Expand Up @@ -161,7 +164,7 @@ async def create(
extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})}
return await self._post(
"/audio/speech",
body=maybe_transform(
body=await async_maybe_transform(
{
"input": input,
"model": model,
Expand Down
9 changes: 7 additions & 2 deletions src/openai/resources/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from ... import _legacy_response
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
from ..._utils import (
extract_files,
maybe_transform,
deepcopy_minimal,
async_maybe_transform,
)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -200,7 +205,7 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/audio/transcriptions",
body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
body=await async_maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
Expand Down
9 changes: 7 additions & 2 deletions src/openai/resources/audio/translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from ... import _legacy_response
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
from ..._utils import (
extract_files,
maybe_transform,
deepcopy_minimal,
async_maybe_transform,
)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -174,7 +179,7 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/audio/translations",
body=maybe_transform(body, translation_create_params.TranslationCreateParams),
body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
Expand Down
9 changes: 6 additions & 3 deletions src/openai/resources/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
AsyncFilesWithStreamingResponse,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._utils import (
maybe_transform,
async_maybe_transform,
)
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -410,7 +413,7 @@ async def create(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
"/assistants",
body=maybe_transform(
body=await async_maybe_transform(
{
"model": model,
"description": description,
Expand Down Expand Up @@ -525,7 +528,7 @@ async def update(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/assistants/{assistant_id}",
body=maybe_transform(
body=await async_maybe_transform(
{
"description": description,
"file_ids": file_ids,
Expand Down
7 changes: 5 additions & 2 deletions src/openai/resources/beta/assistants/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

from .... import _legacy_response
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._utils import (
maybe_transform,
async_maybe_transform,
)
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -259,7 +262,7 @@ async def create(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/assistants/{assistant_id}/files",
body=maybe_transform({"file_id": file_id}, file_create_params.FileCreateParams),
body=await async_maybe_transform({"file_id": file_id}, file_create_params.FileCreateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
Expand Down
9 changes: 6 additions & 3 deletions src/openai/resources/beta/threads/messages/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
AsyncFilesWithStreamingResponse,
)
from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ....._utils import maybe_transform
from ....._utils import (
maybe_transform,
async_maybe_transform,
)
from ....._compat import cached_property
from ....._resource import SyncAPIResource, AsyncAPIResource
from ....._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -315,7 +318,7 @@ async def create(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/threads/{thread_id}/messages",
body=maybe_transform(
body=await async_maybe_transform(
{
"content": content,
"role": role,
Expand Down Expand Up @@ -404,7 +407,7 @@ async def update(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/threads/{thread_id}/messages/{message_id}",
body=maybe_transform({"metadata": metadata}, message_update_params.MessageUpdateParams),
body=await async_maybe_transform({"metadata": metadata}, message_update_params.MessageUpdateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
Expand Down
11 changes: 7 additions & 4 deletions src/openai/resources/beta/threads/runs/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
AsyncStepsWithStreamingResponse,
)
from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ....._utils import maybe_transform
from ....._utils import (
maybe_transform,
async_maybe_transform,
)
from ....._compat import cached_property
from ....._resource import SyncAPIResource, AsyncAPIResource
from ....._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -430,7 +433,7 @@ async def create(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/threads/{thread_id}/runs",
body=maybe_transform(
body=await async_maybe_transform(
{
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
Expand Down Expand Up @@ -521,7 +524,7 @@ async def update(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/threads/{thread_id}/runs/{run_id}",
body=maybe_transform({"metadata": metadata}, run_update_params.RunUpdateParams),
body=await async_maybe_transform({"metadata": metadata}, run_update_params.RunUpdateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
Expand Down Expand Up @@ -669,7 +672,7 @@ async def submit_tool_outputs(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/threads/{thread_id}/runs/{run_id}/submit_tool_outputs",
body=maybe_transform(
body=await async_maybe_transform(
{"tool_outputs": tool_outputs}, run_submit_tool_outputs_params.RunSubmitToolOutputsParams
),
options=make_request_options(
Expand Down
11 changes: 7 additions & 4 deletions src/openai/resources/beta/threads/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
AsyncMessagesWithStreamingResponse,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._utils import (
maybe_transform,
async_maybe_transform,
)
from .runs.runs import Runs, AsyncRuns
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
Expand Down Expand Up @@ -342,7 +345,7 @@ async def create(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
"/threads",
body=maybe_transform(
body=await async_maybe_transform(
{
"messages": messages,
"metadata": metadata,
Expand Down Expand Up @@ -423,7 +426,7 @@ async def update(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
f"/threads/{thread_id}",
body=maybe_transform({"metadata": metadata}, thread_update_params.ThreadUpdateParams),
body=await async_maybe_transform({"metadata": metadata}, thread_update_params.ThreadUpdateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
Expand Down Expand Up @@ -517,7 +520,7 @@ async def create_and_run(
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
return await self._post(
"/threads/runs",
body=maybe_transform(
body=await async_maybe_transform(
{
"assistant_id": assistant_id,
"instructions": instructions,
Expand Down
8 changes: 6 additions & 2 deletions src/openai/resources/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from ... import _legacy_response
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import required_args, maybe_transform
from ..._utils import (
required_args,
maybe_transform,
async_maybe_transform,
)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
Expand Down Expand Up @@ -1329,7 +1333,7 @@ async def create(
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
return await self._post(
"/chat/completions",
body=maybe_transform(
body=await async_maybe_transform(
{
"messages": messages,
"model": model,
Expand Down
Loading

0 comments on commit 31bfc12

Please sign in to comment.