Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/llama_stack_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from . import types
from ._base_client import DefaultAsyncHttpxClient, DefaultHttpxClient
from ._types import NoneType, NOT_GIVEN, NotGiven, Omit, ProxiesTypes, Transport
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import (
AsyncClient,
AsyncLlamaStackClient,
AsyncStream,
Client,
LlamaStackClient,
RequestOptions,
Stream,
Timeout,
Transport,
AsyncClient,
AsyncStream,
RequestOptions,
LlamaStackClient,
AsyncLlamaStackClient,
)
from ._constants import DEFAULT_CONNECTION_LIMITS, DEFAULT_MAX_RETRIES, DEFAULT_TIMEOUT
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
from ._exceptions import (
APIConnectionError,
APIError,
APIResponseValidationError,
ConflictError,
NotFoundError,
APIStatusError,
RateLimitError,
APITimeoutError,
AuthenticationError,
BadRequestError,
ConflictError,
APIConnectionError,
AuthenticationError,
InternalServerError,
LlamaStackClientError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
UnprocessableEntityError,
APIResponseValidationError,
)
from ._models import BaseModel
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._utils import file_from_path
from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
from ._version import __title__, __version__

__all__ = [
"types",
Expand Down
25 changes: 18 additions & 7 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
yield LogEvent(role=None, content="", end="", color="grey")
return

step_type = event.payload.step_type.value
# TODO: discrepancy between DirectClient & HTTP Client
step_type = event.payload.step_type if type(event.payload.step_type) == str else event.payload.step_type.value

# handle safety
if step_type == "shield_call" and event_type == "step_complete":
violation = event.payload.step_details.violation
Expand All @@ -98,12 +100,21 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
color="cyan",
)
else:
yield LogEvent(
role=None,
content=event.payload.model_response_text_delta,
end="",
color="yellow",
)
# TODO: discrepancy between DirectClient & HTTP Client
if hasattr(event.payload, "model_response_text_delta"):
yield LogEvent(
role=None,
content=event.payload.model_response_text_delta,
end="",
color="yellow",
)
else:
yield LogEvent(
role=None,
content=event.payload.text_delta_model_response,
end="",
color="yellow",
)
else:
# step complete
yield LogEvent(role=None, content="")
Expand Down
8 changes: 5 additions & 3 deletions src/llama_stack_client/lib/inference/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ class EventLogger:
def log(self, event_generator):
for chunk in event_generator:
event = chunk.event
if event.event_type.value == "start":
# TODO: discrepancy between DirectClient & HTTP Client
event_type = event.event_type if type(event.event_type) == str else event.event_type.value
if event_type == "start":
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type.value == "progress":
elif event_type == "progress":
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type.value == "complete":
elif event_type == "complete":
yield LogEvent("")
104 changes: 102 additions & 2 deletions src/llama_stack_client/resources/datasetio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

from typing import Dict, Union, Iterable

import httpx

from ..types import datasetio_get_rows_paginated_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..types import datasetio_append_rows_params, datasetio_get_rows_paginated_params
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from .._utils import (
maybe_transform,
strip_not_given,
Expand Down Expand Up @@ -45,6 +47,49 @@ def with_streaming_response(self) -> DatasetioResourceWithStreamingResponse:
"""
return DatasetioResourceWithStreamingResponse(self)

def append_rows(
self,
*,
dataset_id: str,
rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]],
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> None:
"""
Args:
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
extra_headers = {
**strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
**(extra_headers or {}),
}
return self._post(
"/alpha/datasetio/append-rows",
body=maybe_transform(
{
"dataset_id": dataset_id,
"rows": rows,
},
datasetio_append_rows_params.DatasetioAppendRowsParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=NoneType,
)

def get_rows_paginated(
self,
*,
Expand Down Expand Up @@ -115,6 +160,49 @@ def with_streaming_response(self) -> AsyncDatasetioResourceWithStreamingResponse
"""
return AsyncDatasetioResourceWithStreamingResponse(self)

async def append_rows(
self,
*,
dataset_id: str,
rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]],
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> None:
"""
Args:
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
extra_headers = {
**strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
**(extra_headers or {}),
}
return await self._post(
"/alpha/datasetio/append-rows",
body=await async_maybe_transform(
{
"dataset_id": dataset_id,
"rows": rows,
},
datasetio_append_rows_params.DatasetioAppendRowsParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=NoneType,
)

async def get_rows_paginated(
self,
*,
Expand Down Expand Up @@ -169,6 +257,9 @@ class DatasetioResourceWithRawResponse:
def __init__(self, datasetio: DatasetioResource) -> None:
self._datasetio = datasetio

self.append_rows = to_raw_response_wrapper(
datasetio.append_rows,
)
self.get_rows_paginated = to_raw_response_wrapper(
datasetio.get_rows_paginated,
)
Expand All @@ -178,6 +269,9 @@ class AsyncDatasetioResourceWithRawResponse:
def __init__(self, datasetio: AsyncDatasetioResource) -> None:
self._datasetio = datasetio

self.append_rows = async_to_raw_response_wrapper(
datasetio.append_rows,
)
self.get_rows_paginated = async_to_raw_response_wrapper(
datasetio.get_rows_paginated,
)
Expand All @@ -187,6 +281,9 @@ class DatasetioResourceWithStreamingResponse:
def __init__(self, datasetio: DatasetioResource) -> None:
self._datasetio = datasetio

self.append_rows = to_streamed_response_wrapper(
datasetio.append_rows,
)
self.get_rows_paginated = to_streamed_response_wrapper(
datasetio.get_rows_paginated,
)
Expand All @@ -196,6 +293,9 @@ class AsyncDatasetioResourceWithStreamingResponse:
def __init__(self, datasetio: AsyncDatasetioResource) -> None:
self._datasetio = datasetio

self.append_rows = async_to_streamed_response_wrapper(
datasetio.append_rows,
)
self.get_rows_paginated = async_to_streamed_response_wrapper(
datasetio.get_rows_paginated,
)
Loading
Loading