
feature/ introduce new model - gemini #31

Open · wants to merge 1 commit into main
11 changes: 11 additions & 0 deletions README.md
@@ -107,6 +107,7 @@ from strands import Agent
from strands.models import BedrockModel
from strands.models.ollama import OllamaModel
from strands.models.llamaapi import LlamaAPIModel
from strands.models.gemini import GeminiModel

# Bedrock
bedrock_model = BedrockModel(
@@ -130,11 +131,21 @@ llama_model = LlamaAPIModel(
)
agent = Agent(model=llama_model)
response = agent("Tell me about Agentic AI")

# Gemini
gemini_model = GeminiModel(
    model_id="gemini-pro",
    max_tokens=1024,
    params={"temperature": 0.7}
)
agent = Agent(model=gemini_model)
response = agent("Tell me about Agentic AI")
```

Built-in providers:
- [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/)
- [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/)
- [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/)
- [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
- [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
- [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
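
The README example above does not show where credentials come from. A minimal sketch of wiring an API key into the new provider, assuming the `GOOGLE_API_KEY` environment variable used by this PR's integration test, might look like:

```python
import os

from strands import Agent
from strands.models.gemini import GeminiModel

# client_args are forwarded to the underlying Gemini client by GeminiModel.__init__ in this PR;
# GOOGLE_API_KEY is an assumed environment variable name, mirroring tests-integ/test_model_gemini.py.
gemini_model = GeminiModel(
    client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
    model_id="gemini-pro",
    max_tokens=1024,
    params={"temperature": 0.7},
)

agent = Agent(model=gemini_model)
response = agent("Tell me about Agentic AI")
```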
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -75,9 +75,12 @@ ollama = [
llamaapi = [
    "llama-api-client>=0.1.0,<1.0.0",
]
gemini = [
    "google-generativeai>=0.8.5",
]

[tool.hatch.envs.hatch-static-analysis]
features = ["anthropic", "litellm", "llamaapi", "ollama"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "gemini"]
dependencies = [
    "mypy>=1.15.0,<2.0.0",
    "ruff>=0.11.6,<0.12.0",
269 changes: 269 additions & 0 deletions src/strands/models/gemini.py
@@ -0,0 +1,269 @@
"""Google Gemini model provider.

- Docs: https://ai.google.dev/docs/gemini_api_overview
"""

import base64
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, TypedDict

import google.generativeai.generative_models as genai # mypy: disable-error-code=import
from typing_extensions import Required, Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)


class GeminiModel(Model):
"""Google Gemini model provider implementation."""

EVENT_TYPES = {
"message_start",
"content_block_start",
"content_block_delta",
"content_block_stop",
"message_stop",
}

OVERFLOW_MESSAGES = {
"input is too long",
"input length exceeds context window",
"input and output tokens exceed your context limit",
}

class GeminiConfig(TypedDict, total=False):
"""Configuration options for Gemini models.

Attributes:
max_tokens: Maximum number of tokens to generate.
model_id: Gemini model ID (e.g., "gemini-pro").
For a complete list of supported models, see
https://ai.google.dev/models/gemini.
params: Additional model parameters (e.g., temperature).
For a complete list of supported parameters, see
https://ai.google.dev/docs/gemini_api_overview#generation_config.
"""

max_tokens: Required[int]
model_id: Required[str]
params: Optional[dict[str, Any]]

def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[GeminiConfig]):
"""Initialize provider instance.

Args:
client_args: Arguments for the underlying Gemini client (e.g., api_key).
For a complete list of supported arguments, see
https://ai.google.dev/docs/gemini_api_overview#client_libraries.
**model_config: Configuration options for the Gemini model.
"""
self.config = GeminiModel.GeminiConfig(**model_config)

logger.debug("config=<%s> | initializing", self.config)

client_args = client_args or {}
genai.client.configure(**client_args)
self.model = genai.GenerativeModel(self.config["model_id"])

@override
def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override]
"""Update the Gemini model configuration with the provided arguments.

Args:
**model_config: Configuration overrides.
"""
self.config.update(model_config)
self.model = genai.GenerativeModel(self.config["model_id"])

@override
def get_config(self) -> GeminiConfig:
"""Get the Gemini model configuration.

Returns:
The Gemini model configuration.
"""
return self.config

def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
"""Format a Gemini content block.

Args:
content: Message content.

Returns:
Gemini formatted content block.
"""
if "image" in content:
return {
"inline_data": {
"data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
"mime_type": mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream"),
}
}

if "text" in content:
return {"text": content["text"]}

return {"text": json.dumps(content)}

def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
"""Format a Gemini messages array.

Args:
messages: List of message objects to be processed by the model.

Returns:
A Gemini messages array.
"""
formatted_messages = []

for message in messages:
formatted_contents = []

for content in message["content"]:
if "cachePoint" in content:
continue

formatted_contents.append(self._format_request_message_content(content))

if formatted_contents:
formatted_messages.append({"role": message["role"], "parts": formatted_contents})

return formatted_messages

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
"""Format a Gemini streaming request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
A Gemini streaming request.
"""
generation_config = {"max_output_tokens": self.config["max_tokens"], **(self.config.get("params") or {})}

return {
"contents": self._format_request_messages(messages),
"generation_config": generation_config,
"tools": [
{
"function_declarations": [
{
"name": tool_spec["name"],
"description": tool_spec["description"],
"parameters": tool_spec["inputSchema"]["json"],
}
for tool_spec in tool_specs or []
]
}
]
if tool_specs
else None,
"system_instruction": system_prompt,
}

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Gemini response events into standardized message chunks.

Args:
event: A response event from the Gemini model.

Returns:
The formatted chunk.

Raises:
RuntimeError: If chunk_type is not recognized.
This error should never be encountered as we control chunk_type in the stream method.
"""
match event["type"]:
case "message_start":
return {"messageStart": {"role": "assistant"}}

case "content_block_start":
return {"contentBlockStart": {"start": {}}}

case "content_block_delta":
return {"contentBlockDelta": {"delta": {"text": event["text"]}}}

case "content_block_stop":
return {"contentBlockStop": {}}

case "message_stop":
return {"messageStop": {"stopReason": event["stop_reason"]}}

case "metadata":
return {
"metadata": {
"usage": {
"inputTokens": event["usage"]["prompt_token_count"],
"outputTokens": event["usage"]["candidates_token_count"],
"totalTokens": event["usage"]["total_token_count"],
},
"metrics": {
"latencyMs": 0,
},
}
}

case _:
raise RuntimeError(f"event_type=<{event['type']} | unknown type")

@override
def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
"""Send the request to the Gemini model and get the streaming response.

Args:
request: The formatted request to send to the Gemini model.

Returns:
An iterable of response events from the Gemini model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by Gemini.
"""
try:
response = self.model.generate_content(**request, stream=True)

yield {"type": "message_start"}
yield {"type": "content_block_start"}

for chunk in response:
if chunk.text:
yield {"type": "content_block_delta", "text": chunk.text}

yield {"type": "content_block_stop"}
yield {"type": "message_stop", "stop_reason": "end_turn"}

# Get usage information
usage = response.usage_metadata
yield {
"type": "metadata",
"usage": {
"prompt_token_count": usage.prompt_token_count,
"candidates_token_count": usage.candidates_token_count,
"total_token_count": usage.total_token_count,
},
}

except Exception as error:
if "quota" in str(error).lower():
raise ModelThrottledException(str(error)) from error

if any(overflow_message in str(error).lower() for overflow_message in GeminiModel.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(error)) from error

raise error
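
For reviewers, here is a minimal sketch of how the three methods added in this file compose; the message payload is illustrative, and in practice the Agent event loop drives this pipeline rather than user code:

```python
from strands.models.gemini import GeminiModel

# Credentials are configured through client_args, as in the README example.
model = GeminiModel(model_id="gemini-pro", max_tokens=512)

# format_request builds the Gemini request from standardized messages and an optional system prompt.
request = model.format_request(
    messages=[{"role": "user", "content": [{"text": "Hello"}]}],
    system_prompt="You are a helpful assistant.",
)

# stream yields provider events; format_chunk maps each one to a standardized StreamEvent chunk.
for event in model.stream(request):
    print(model.format_chunk(event))
```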
51 changes: 51 additions & 0 deletions tests-integ/test_model_gemini.py
@@ -0,0 +1,51 @@
"""Integration tests for the Gemini model provider."""

import os

import pytest

import strands
from strands import Agent
from strands.models.gemini import GeminiModel


@pytest.fixture
def model():
return GeminiModel(
client_args={
"api_key": os.getenv("GOOGLE_API_KEY"),
},
model_id="gemini-pro",
max_tokens=512,
)


@pytest.fixture
def tools():
@strands.tool
def tool_time() -> str:
return "12:00"

@strands.tool
def tool_weather() -> str:
return "sunny"

return [tool_time, tool_weather]


@pytest.fixture
def system_prompt():
return "You are an AI assistant that uses & instead of ."


@pytest.fixture
def agent(model, tools, system_prompt):
return Agent(model=model, tools=tools, system_prompt=system_prompt)


@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY environment variable missing")
def test_agent(agent):
result = agent("What is the time and weather in New York?")
text = result.message["content"][0]["text"].lower()

assert all(string in text for string in ["12:00", "sunny", "&"])
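
The integration test above needs a live `GOOGLE_API_KEY`. As a complement, a network-free sketch of a unit test for the chunk mapping in `gemini.py` could look like the following; the dummy API key is an assumption, since `format_chunk` itself never calls the API:

```python
from strands.models.gemini import GeminiModel


def test_format_chunk_content_block_delta():
    # A placeholder key is enough here; only the event-to-chunk mapping is exercised.
    model = GeminiModel(client_args={"api_key": "test"}, model_id="gemini-pro", max_tokens=16)

    chunk = model.format_chunk({"type": "content_block_delta", "text": "hello"})

    assert chunk == {"contentBlockDelta": {"delta": {"text": "hello"}}}
```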