355 changes: 192 additions & 163 deletions src/llama_stack_client/_client.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/llama_stack_client/_models.py
@@ -46,6 +46,7 @@
     strip_not_given,
     extract_type_arg,
     is_annotated_type,
+    is_type_alias_type,
     strip_annotated_type,
 )
 from ._compat import (
@@ -428,6 +429,8 @@ def construct_type(*, value: object, type_: object) -> object:
     # we allow `object` as the input type because otherwise, passing things like
     # `Literal['value']` will be reported as a type error by type checkers
     type_ = cast("type[object]", type_)
+    if is_type_alias_type(type_):
+        type_ = type_.__value__  # type: ignore[unreachable]
 
     # unwrap `Annotated[T, ...]` -> `T`
     if is_annotated_type(type_):
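A note on what the `_models.py` change enables: `construct_type` now resolves a `TypeAliasType` to its underlying value before any other unwrapping. A minimal sketch, assuming the `typing_extensions` backport (the alias name and values are illustrative):

```python
from typing import List

from typing_extensions import TypeAliasType

from llama_stack_client._models import construct_type

# Illustrative alias; a native `type Scores = list[float]` statement on
# Python 3.12+ would behave the same way.
Scores = TypeAliasType("Scores", List[float])

# The new branch unwraps the alias to List[float] via `__value__`
# before construction proceeds.
values = construct_type(value=[1.0, 2.5], type_=Scores)
print(values)  # [1.0, 2.5]
```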
20 changes: 10 additions & 10 deletions src/llama_stack_client/_response.py
@@ -25,7 +25,7 @@
 import pydantic
 
 from ._types import NoneType
-from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base
+from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
 from ._models import BaseModel, is_basemodel
 from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
 from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
@@ -126,9 +126,15 @@ def __repr__(self) -> str:
         )
 
     def _parse(self, *, to: type[_T] | None = None) -> R | _T:
+        cast_to = to if to is not None else self._cast_to
+
+        # unwrap `TypeAlias('Name', T)` -> `T`
+        if is_type_alias_type(cast_to):
+            cast_to = cast_to.__value__  # type: ignore[unreachable]
+
         # unwrap `Annotated[T, ...]` -> `T`
-        if to and is_annotated_type(to):
-            to = extract_type_arg(to, 0)
+        if cast_to and is_annotated_type(cast_to):
+            cast_to = extract_type_arg(cast_to, 0)
 
         if self._is_sse_stream:
             if to:
@@ -164,18 +170,12 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
             return cast(
                 R,
                 stream_cls(
-                    cast_to=self._cast_to,
+                    cast_to=cast_to,
                     response=self.http_response,
                     client=cast(Any, self._client),
                 ),
             )
 
-        cast_to = to if to is not None else self._cast_to
-
-        # unwrap `Annotated[T, ...]` -> `T`
-        if is_annotated_type(cast_to):
-            cast_to = extract_type_arg(cast_to, 0)
-
         if cast_to is NoneType:
             return cast(R, None)
 
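With the unwrap hoisted to the top of `_parse`, the resolved `cast_to` is what reaches the streaming branch, so a `TypeAliasType` cast target now works for streams as well. A standalone sketch of the new unwrap order (the helper name `unwrap_cast_to` is illustrative, not part of the SDK):

```python
from typing_extensions import Annotated, TypeAliasType, get_args, get_origin


def unwrap_cast_to(cast_to: object) -> object:
    # `TypeAlias('Name', T)` -> `T`, first
    if isinstance(cast_to, TypeAliasType):
        cast_to = cast_to.__value__
    # then `Annotated[T, ...]` -> `T`
    if get_origin(cast_to) is Annotated:
        cast_to = get_args(cast_to)[0]
    return cast_to


Meta = TypeAliasType("Meta", Annotated[int, "metadata"])
assert unwrap_cast_to(Meta) is int
```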
6 changes: 2 additions & 4 deletions src/llama_stack_client/_types.py
@@ -192,10 +192,8 @@ def get(self, __key: str) -> str | None: ...
 StrBytesIntFloat = Union[str, bytes, int, float]
 
 # Note: copied from Pydantic
-# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49
-IncEx: TypeAlias = Union[
-    Set[int], Set[str], Mapping[int, Union["IncEx", Literal[True]]], Mapping[str, Union["IncEx", Literal[True]]]
-]
+# https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79
+IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]]
 
 PostParser = Callable[[Any], Any]
 
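The widened alias tracks newer Pydantic, which accepts plain `bool` values (not just `Literal[True]`) in include/exclude mappings. A quick illustration against Pydantic v2 directly (the model and field names are made up):

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    email: str


user = User(name="ada", email="ada@example.com")

# `False` entries are now type-correct under the updated IncEx alias;
# True excludes the field, False leaves it in.
print(user.model_dump(exclude={"email": True, "name": False}))
# {'name': 'ada'}
```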
1 change: 1 addition & 0 deletions src/llama_stack_client/_utils/__init__.py
@@ -39,6 +39,7 @@
     is_iterable_type as is_iterable_type,
     is_required_type as is_required_type,
     is_annotated_type as is_annotated_type,
+    is_type_alias_type as is_type_alias_type,
     strip_annotated_type as strip_annotated_type,
     extract_type_var_from_base as extract_type_var_from_base,
 )
31 changes: 30 additions & 1 deletion src/llama_stack_client/_utils/_typing.py
@@ -1,8 +1,17 @@
 from __future__ import annotations
 
+import sys
+import typing
+import typing_extensions
 from typing import Any, TypeVar, Iterable, cast
 from collections import abc as _c_abc
-from typing_extensions import Required, Annotated, get_args, get_origin
+from typing_extensions import (
+    TypeIs,
+    Required,
+    Annotated,
+    get_args,
+    get_origin,
+)
 
 from .._types import InheritsGeneric
 from .._compat import is_union as _is_union
@@ -36,6 +45,26 @@ def is_typevar(typ: type) -> bool:
     return type(typ) == TypeVar  # type: ignore
 
 
+_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
+if sys.version_info >= (3, 12):
+    _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
+
+
+def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
+    """Return whether the provided argument is an instance of `TypeAliasType`.
+
+    ```python
+    type Int = int
+    is_type_alias_type(Int)
+    # > True
+    Str = TypeAliasType("Str", str)
+    is_type_alias_type(Str)
+    # > True
+    ```
+    """
+    return isinstance(tp, _TYPE_ALIAS_TYPES)
+
+
 # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
 def strip_annotated_type(typ: type) -> type:
     if is_required_type(typ) or is_annotated_type(typ):
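The `_TYPE_ALIAS_TYPES` tuple exists because on Python 3.12+ a native `type X = ...` statement produces `typing.TypeAliasType`, which can be a distinct class from the `typing_extensions` backport, so both are checked. A hedged sketch of the two paths:

```python
import sys

from typing_extensions import TypeAliasType

from llama_stack_client._utils import is_type_alias_type

# Alias built with the typing_extensions backport class.
Str = TypeAliasType("Str", str)
assert is_type_alias_type(Str)
assert not is_type_alias_type(str)

if sys.version_info >= (3, 12):
    # The native `type` statement (exec'd here so this snippet still parses
    # on 3.11) creates typing.TypeAliasType, which the tuple also covers.
    ns: dict = {}
    exec("type Int = int", ns)
    assert is_type_alias_type(ns["Int"])
```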
5 changes: 3 additions & 2 deletions src/llama_stack_client/resources/batch_inference.py
@@ -25,6 +25,7 @@
 from .._base_client import make_request_options
 from ..types.shared.batch_completion import BatchCompletion
 from ..types.shared_params.sampling_params import SamplingParams
+from ..types.shared_params.interleaved_content import InterleavedContent
 from ..types.batch_inference_chat_completion_response import BatchInferenceChatCompletionResponse
 
 __all__ = ["BatchInferenceResource", "AsyncBatchInferenceResource"]
@@ -115,7 +116,7 @@ def chat_completion(
     def completion(
         self,
         *,
-        content_batch: List[batch_inference_completion_params.ContentBatch],
+        content_batch: List[InterleavedContent],
         model: str,
         logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
         sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
@@ -244,7 +245,7 @@ async def chat_completion(
     async def completion(
         self,
         *,
-        content_batch: List[batch_inference_completion_params.ContentBatch],
+        content_batch: List[InterleavedContent],
         model: str,
         logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
         sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
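`content_batch` now uses the shared `InterleavedContent` param type, so plain strings are accepted directly. A hedged usage sketch (the base URL and model name are placeholders):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

# Each batch entry is InterleavedContent; a bare str is the simplest form.
response = client.batch_inference.completion(
    content_batch=[
        "Write a haiku about type aliases.",
        "Summarize PEP 695 in one sentence.",
    ],
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
)
```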
5 changes: 3 additions & 2 deletions src/llama_stack_client/resources/datasets.py
@@ -23,6 +23,7 @@
 )
 from .._base_client import make_request_options
 from ..types.dataset_list_response import DatasetListResponse
+from ..types.shared_params.param_type import ParamType
 from ..types.dataset_retrieve_response import DatasetRetrieveResponse
 
 __all__ = ["DatasetsResource", "AsyncDatasetsResource"]
@@ -124,7 +125,7 @@ def register(
         self,
         *,
         dataset_id: str,
-        dataset_schema: Dict[str, dataset_register_params.DatasetSchema],
+        dataset_schema: Dict[str, ParamType],
         url: str,
         metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         provider_dataset_id: str | NotGiven = NOT_GIVEN,
@@ -306,7 +307,7 @@ async def register(
         self,
         *,
         dataset_id: str,
-        dataset_schema: Dict[str, dataset_register_params.DatasetSchema],
+        dataset_schema: Dict[str, ParamType],
         url: str,
         metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         provider_dataset_id: str | NotGiven = NOT_GIVEN,
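`dataset_schema` values are now typed with the shared `ParamType` union. A hedged sketch of registering a dataset, assuming the usual `{"type": ...}` discriminator shape for column types (the ids and URL are illustrative):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

client.datasets.register(
    dataset_id="my-eval-set",  # placeholder id
    # Each column maps to a ParamType such as {"type": "string"} (assumed shape).
    dataset_schema={
        "input_query": {"type": "string"},
        "expected_answer": {"type": "string"},
    },
    url="https://example.com/eval.jsonl",  # placeholder URL
)
```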
21 changes: 11 additions & 10 deletions src/llama_stack_client/resources/inference.py
@@ -32,6 +32,7 @@
 from ..types.embeddings_response import EmbeddingsResponse
 from ..types.inference_completion_response import InferenceCompletionResponse
 from ..types.shared_params.sampling_params import SamplingParams
+from ..types.shared_params.interleaved_content import InterleavedContent
 from ..types.inference_chat_completion_response import InferenceChatCompletionResponse
 
 __all__ = ["InferenceResource", "AsyncInferenceResource"]
@@ -245,7 +246,7 @@ def completion(
     def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
         response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -275,7 +276,7 @@ def completion(
     def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         stream: Literal[True],
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
@@ -305,7 +306,7 @@ def completion(
     def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         stream: bool,
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
@@ -335,7 +336,7 @@ def completion(
     def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
         response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -383,7 +384,7 @@ def embeddings(
     def embeddings(
         self,
         *,
-        contents: List[inference_embeddings_params.Content],
+        contents: List[InterleavedContent],
         model_id: str,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -631,7 +632,7 @@ async def chat_completion(
     async def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
         response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -661,7 +662,7 @@ async def completion(
     async def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         stream: Literal[True],
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
@@ -691,7 +692,7 @@ async def completion(
     async def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         stream: bool,
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
@@ -721,7 +722,7 @@ async def completion(
     async def completion(
         self,
         *,
-        content: inference_completion_params.Content,
+        content: InterleavedContent,
         model_id: str,
         logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
         response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -769,7 +770,7 @@ async def embeddings(
     async def embeddings(
         self,
         *,
-        contents: List[inference_embeddings_params.Content],
+        contents: List[InterleavedContent],
         model_id: str,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
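Same loosening here: `content` and `contents` take the shared `InterleavedContent` type across all `completion` overloads and `embeddings`. A hedged sketch (the endpoint and model ids are placeholders):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

# A plain string is valid InterleavedContent.
completion = client.inference.completion(
    content="Complete this: the quick brown fox",
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
)

# embeddings() takes a list of InterleavedContent values.
embeddings = client.inference.embeddings(
    contents=["first passage", "second passage"],
    model_id="all-MiniLM-L6-v2",  # placeholder model
)
```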
5 changes: 3 additions & 2 deletions src/llama_stack_client/resources/memory.py
@@ -23,6 +23,7 @@
 )
 from .._base_client import make_request_options
 from ..types.query_documents_response import QueryDocumentsResponse
+from ..types.shared_params.interleaved_content import InterleavedContent
 
 __all__ = ["MemoryResource", "AsyncMemoryResource"]
 
@@ -96,7 +97,7 @@ def query(
         self,
         *,
         bank_id: str,
-        query: memory_query_params.Query,
+        query: InterleavedContent,
         params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -206,7 +207,7 @@ async def query(
         self,
         *,
         bank_id: str,
-        query: memory_query_params.Query,
+        query: InterleavedContent,
         params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
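`query` likewise becomes `InterleavedContent`, so a bare string query works. A hedged sketch (the bank id and endpoint are placeholders):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

results = client.memory.query(
    bank_id="docs-bank",  # placeholder bank id
    query="How do I register an embedding model?",
)
```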
5 changes: 5 additions & 0 deletions src/llama_stack_client/resources/models.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from typing import Dict, Union, Iterable, Optional
+from typing_extensions import Literal
 
 import httpx
 
@@ -124,6 +125,7 @@ def register(
         *,
         model_id: str,
         metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
+        model_type: Literal["llm", "embedding"] | NotGiven = NOT_GIVEN,
         provider_id: str | NotGiven = NOT_GIVEN,
         provider_model_id: str | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
@@ -154,6 +156,7 @@ def register(
                 {
                     "model_id": model_id,
                     "metadata": metadata,
+                    "model_type": model_type,
                     "provider_id": provider_id,
                     "provider_model_id": provider_model_id,
                 },
@@ -301,6 +304,7 @@ async def register(
         *,
         model_id: str,
         metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
+        model_type: Literal["llm", "embedding"] | NotGiven = NOT_GIVEN,
         provider_id: str | NotGiven = NOT_GIVEN,
         provider_model_id: str | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
@@ -331,6 +335,7 @@ async def register(
                 {
                     "model_id": model_id,
                     "metadata": metadata,
+                    "model_type": model_type,
                     "provider_id": provider_id,
                     "provider_model_id": provider_model_id,
                 },
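And the new optional `model_type` field distinguishes generative models from embedding models at registration time. A hedged sketch (the ids are placeholders):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

# Mark embedding models explicitly; generative models can omit model_type.
client.models.register(
    model_id="all-MiniLM-L6-v2",  # placeholder model id
    model_type="embedding",
    provider_id="sentence-transformers",  # placeholder provider id
)
```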