Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.71.0"
version = "0.71.1"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "kacper.kalinski@miquido.com" },
Expand Down
183 changes: 125 additions & 58 deletions src/draive/gemini/lmm_generation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import random
from collections.abc import AsyncIterator, Callable
from itertools import chain
from typing import Any, Literal, overload

try:
from google.api_core.exceptions import ResourceExhausted # type: ignore[import-untyped]
except ImportError:
# Define a fallback for when google-api-core is not available
# (optional dependencies do not get installed automatically)
class ResourceExhausted(Exception): # pragma: no cover
"""Stub raised when google-api-core is absent."""


from google.api_core.exceptions import ResourceExhausted # pyright: ignore[reportMissingImport]
from google.genai.types import (
Candidate,
Content,
Expand All @@ -20,6 +13,7 @@ class ResourceExhausted(Exception): # pragma: no cover
FunctionDeclarationDict,
GenerateContentConfigDict,
GenerateContentResponse,
GenerateContentResponseUsageMetadata,
MediaResolution,
Modality,
SchemaDict,
Expand Down Expand Up @@ -204,33 +198,54 @@ async def _completion(
functions: list[FunctionDeclarationDict] | None,
output_decoder: Callable[[MultimodalContent], MultimodalContent],
) -> LMMOutput:
"""Generate non-streaming completion."""
try:
completion: GenerateContentResponse = await self._client.aio.models.generate_content(
model=model,
config=config,
contents=content,
)

except ResourceExhausted as exc:
ctx.record(ObservabilityLevel.WARNING, event="lmm.rate_limit")
raise RateLimitError(retry_after=0) from exc
ctx.record(
ObservabilityLevel.WARNING,
event="lmm.rate_limit",
)
raise RateLimitError(
retry_after=random.uniform(0.3, 3.0), # nosec: B311
) from exc

except Exception as exc:
raise GeminiException(f"Failed to generate Gemini completion: {exc}") from exc

self._record_usage_metrics(completion.usage_metadata, model)
self._record_usage_metrics(
completion.usage_metadata,
model=model,
)

if not completion.candidates:
raise GeminiException("Invalid Gemini completion - missing candidates!", completion)

completion_choice: Candidate = completion.candidates[0]
self._validate_finish_reason(completion_choice.finish_reason, completion)
self._validate_finish_reason(
completion_choice.finish_reason,
completion=completion,
)

completion_content = self._extract_completion_content(completion_choice)
completion_content: Content = self._extract_completion_content(completion_choice)
result_content_elements: list[MultimodalContentElement]
tool_requests: list[LMMToolRequest]
result_content_elements, tool_requests = self._process_completion_parts(completion_content)

lmm_completion = self._create_lmm_completion(result_content_elements, output_decoder)
lmm_completion: LMMCompletion | None = self._create_lmm_completion(
result_content_elements,
output_decoder=output_decoder,
)

return self._handle_completion_result(lmm_completion, tool_requests, functions)
return self._handle_completion_result(
lmm_completion,
tool_requests=tool_requests,
functions=functions,
)

async def _completion_stream(
self,
Expand All @@ -247,9 +262,16 @@ async def _completion_stream(
config=config,
contents=content,
)

except ResourceExhausted as exc:
ctx.record(ObservabilityLevel.WARNING, event="lmm.rate_limit")
raise RateLimitError(retry_after=0) from exc
ctx.record(
ObservabilityLevel.WARNING,
event="lmm.rate_limit",
)
raise RateLimitError(
retry_after=random.uniform(0.3, 3.0), # nosec: B311
) from exc

except Exception as exc:
raise GeminiException(f"Failed to initialize Gemini streaming: {exc}") from exc

Expand All @@ -268,12 +290,14 @@ async def _stream_processor( # noqa PLR0912
functions: list[FunctionDeclarationDict] | None,
output_decoder: Callable[[MultimodalContent], MultimodalContent],
):
"""Process streaming completion chunks."""
accumulated_tool_calls: list[LMMToolRequest] = []

async for completion_chunk in completion_stream:
# Record usage if available (expected in the last chunk)
self._record_usage_metrics(completion_chunk.usage_metadata, model)
self._record_usage_metrics(
completion_chunk.usage_metadata,
model=model,
)

if not completion_chunk.candidates:
continue # Skip chunks without candidates
Expand All @@ -288,8 +312,13 @@ async def _stream_processor( # noqa PLR0912
accumulated_tool_calls = self._accumulate_tool_call(
accumulated_tool_calls, element
)

else:
yield LMMStreamChunk.of(output_decoder(MultimodalContent.of(element)))
yield LMMStreamChunk.of(
output_decoder(
MultimodalContent.of(element),
),
)

# Handle finish reason
if finish_reason := completion_choice.finish_reason:
Expand All @@ -301,77 +330,106 @@ async def _stream_processor( # noqa PLR0912
"in the configuration. This indicates a mismatch between "
"the model's response and the provided tools."
)

for tool_request in accumulated_tool_calls:
ctx.record(
ObservabilityLevel.INFO,
event="lmm.tool_request",
attributes={"lmm.tool": tool_request.tool},
)
yield tool_request

else:
yield LMMStreamChunk.of(MultimodalContent.empty, eod=True)
break
yield LMMStreamChunk.of(
MultimodalContent.empty,
eod=True,
)

return # end of processing

elif finish_reason == FinishReason.MAX_TOKENS:
raise GeminiException(
"Invalid Gemini completion - exceeded maximum length!",
completion_chunk,
)

else:
raise GeminiException(
f"Gemini completion generation failed! Reason: {finish_reason}",
completion_chunk,
)

def _record_usage_metrics(self, usage_metadata, model: str) -> None:
"""Record token usage metrics."""
if usage := usage_metadata:
ctx.record(
ObservabilityLevel.INFO,
metric="lmm.input_tokens",
value=usage.prompt_token_count or 0,
unit="tokens",
attributes={"lmm.model": model},
)
ctx.record(
ObservabilityLevel.INFO,
metric="lmm.input_tokens.cached",
value=usage.cached_content_token_count or 0,
unit="tokens",
attributes={"lmm.model": model},
)
ctx.record(
ObservabilityLevel.INFO,
metric="lmm.output_tokens",
value=usage.candidates_token_count or 0,
unit="tokens",
attributes={"lmm.model": model},
)
def _record_usage_metrics(
    self,
    usage: GenerateContentResponseUsageMetadata | None,
    *,
    model: str,
) -> None:
    """Record prompt, cached-prompt and output token counts as metrics.

    No-op when the response (or stream chunk) carried no usage metadata;
    Gemini typically attaches it only to the final chunk of a stream.
    """
    if usage is None:
        return  # nothing to report

    # (metric name, raw token count) pairs; counts may be None -> report 0.
    token_metrics: tuple[tuple[str, int | None], ...] = (
        ("lmm.input_tokens", usage.prompt_token_count),
        ("lmm.input_tokens.cached", usage.cached_content_token_count),
        ("lmm.output_tokens", usage.candidates_token_count),
    )
    for metric_name, token_count in token_metrics:
        ctx.record(
            ObservabilityLevel.INFO,
            metric=metric_name,
            value=token_count or 0,
            unit="tokens",
            attributes={"lmm.model": model},
        )

def _validate_finish_reason(self, finish_reason, completion) -> None:
"""Validate completion finish reason."""
def _validate_finish_reason(
self,
finish_reason: FinishReason | None,
*,
completion: GenerateContentResponse,
) -> None:
match finish_reason:
case None:
pass # not finished

case FinishReason.STOP:
pass # Valid completion

case FinishReason.MAX_TOKENS:
raise GeminiException(
"Invalid Gemini completion - exceeded maximum length!", completion
"Invalid Gemini completion - exceeded maximum length!",
completion,
)

case reason:
raise GeminiException(
f"Gemini completion generation failed! Reason: {reason}", completion
f"Gemini completion generation failed! Reason: {reason}",
completion,
)

def _extract_completion_content(self, completion_choice: Candidate) -> Content:
"""Extract content from completion choice."""
def _extract_completion_content(
    self,
    completion_choice: Candidate,
) -> Content:
    """Return the content carried by the selected candidate.

    Raises:
        GeminiException: when the candidate has no (truthy) content.
    """
    candidate_content = completion_choice.content
    if not candidate_content:
        raise GeminiException("Missing Gemini completion content!")

    return candidate_content

def _process_completion_parts(
self, completion_content: Content
self,
completion_content: Content,
) -> tuple[list[MultimodalContentElement], list[LMMToolRequest]]:
"""Process completion parts into content elements and tool requests."""
result_content_elements: list[MultimodalContentElement] = []
tool_requests: list[LMMToolRequest] = []

Expand All @@ -388,6 +446,7 @@ def _process_completion_parts(
def _create_lmm_completion(
self,
result_content_elements: list[MultimodalContentElement],
*,
output_decoder: Callable[[MultimodalContent], MultimodalContent],
) -> LMMCompletion | None:
"""Create LMM completion from content elements."""
Expand All @@ -398,6 +457,7 @@ def _create_lmm_completion(
def _handle_completion_result(
self,
lmm_completion: LMMCompletion | None,
*,
tool_requests: list[LMMToolRequest],
functions: list[FunctionDeclarationDict] | None,
) -> LMMOutput:
Expand All @@ -409,19 +469,26 @@ def _handle_completion_result(
"in the configuration. This indicates a mismatch between "
"the model's response and the provided tools."
)

completion_tool_calls = LMMToolRequests(
content=lmm_completion.content if lmm_completion else None,
requests=tool_requests,
)

ctx.record(
ObservabilityLevel.INFO,
event="lmm.tool_requests",
attributes={"lmm.tools": [call.tool for call in tool_requests]},
)
return completion_tool_calls

elif lmm_completion:
ctx.record(ObservabilityLevel.INFO, event="lmm.completion")
ctx.record(
ObservabilityLevel.INFO,
event="lmm.completion",
)
return lmm_completion

else:
raise GeminiException("Invalid Gemini completion, missing content!")

Expand Down Expand Up @@ -491,7 +558,7 @@ def _build_request( # noqa PLR0913
"stop_sequences": stop_sequences,
# gemini safety is really bad and often triggers false positive
"safety_settings": DISABLED_SAFETY_SETTINGS,
"system_instruction": instruction if instruction is not None else {},
"system_instruction": instruction,
"tools": [{"function_declarations": functions}] if functions else None,
"tool_config": {"function_calling_config": {"mode": function_calling_mode}}
if function_calling_mode
Expand Down
1 change: 0 additions & 1 deletion src/draive/instructions/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


async def _empty(
    **extra: Any,
) -> Sequence[InstructionDeclaration]:
    """Default instruction source: always yields no declarations.

    Accepts and ignores arbitrary keyword arguments so it can stand in
    wherever a declaration-listing callable is expected.
    """
    return ()
Expand Down
4 changes: 2 additions & 2 deletions src/draive/tools/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
"_check_availability",
availability_check
or (
lambda meta: True # available by default
lambda _: True # available by default
),
)
self._format_result: ToolResultFormatting[Result]
Expand Down Expand Up @@ -150,7 +150,7 @@ def available(self) -> bool:

except Exception as e:
ctx.log_error(
"Availability check exception",
f"Availability check of tool ({self.name}) failed, tool will be unavailable.",
exception=e,
)
return False
Expand Down
Loading