Integrate portkey.ai as multi-provider gateway with tool use support #2

Draft · wants to merge 10 commits into base: main
56 changes: 56 additions & 0 deletions .github/workflows/integration-test.yml
@@ -0,0 +1,56 @@
name: Secure Integration test

on:
pull_request_target:
types: [opened, synchronize, labeled, unlabeled, reopened]

jobs:
check-access-and-checkout:
runs-on: ubuntu-latest
permissions:
id-token: write
pull-requests: read
contents: read
steps:
- name: Check PR labels and author
id: check
uses: actions/github-script@v7
with:
script: |
const pr = context.payload.pull_request;

const labels = pr.labels.map(label => label.name);
const hasLabel = labels.includes('approved-for-integ-test')
if (hasLabel) {
core.info('PR contains label approved-for-integ-test')
return
}

core.setFailed('Pull request must have the label approved-for-integ-test')
- name: Configure Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }}
aws-region: us-east-1
mask-aws-account-id: true
- name: Checkout base branch
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }} # Pull the commit from the forked repo
persist-credentials: false # Don't persist credentials for subsequent actions
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install --no-cache-dir hatch
- name: Run integration tests
env:
AWS_REGION: us-east-1
AWS_REGION_NAME: us-east-1 # Needed for LiteLLM
id: tests
run: |
hatch test tests-integ


12 changes: 12 additions & 0 deletions README.md
@@ -118,6 +118,7 @@ from strands import Agent
from strands.models import BedrockModel
from strands.models.ollama import OllamaModel
from strands.models.llamaapi import LlamaAPIModel
from strands.models.portkey import PortkeyModel

# Bedrock
bedrock_model = BedrockModel(
@@ -142,6 +143,17 @@ llama_model = LlamaAPIModel(
)
agent = Agent(model=llama_model)
response = agent("Tell me about Agentic AI")

# Portkey for all models
portkey_model = PortkeyModel(
api_key="<PORTKEY_API_KEY>",
model_id="anthropic.claude-3-5-sonnet-20241022-v2:0",
virtual_key="<BEDROCK_VIRTUAL_KEY>",
provider="bedrock",
base_url="http://portkey-service-gateway.service.prod.example.com/v1",
)
agent = Agent(model=portkey_model)
response = agent("Tell me about Agentic AI")
```
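
Since this integration advertises tool use through the gateway, a tool-enabled agent should presumably wire up the same way as with the other providers. A minimal sketch, where the `current_time` tool is a hypothetical example and `@tool`/`tools` come from the core SDK:

```python
from datetime import datetime
from zoneinfo import ZoneInfo

from strands import Agent, tool
from strands.models.portkey import PortkeyModel

@tool
def current_time(timezone: str) -> str:
    """Return the current time in the given IANA timezone (hypothetical example tool)."""
    return datetime.now(ZoneInfo(timezone)).isoformat()

portkey_model = PortkeyModel(
    api_key="<PORTKEY_API_KEY>",
    model_id="anthropic.claude-3-5-sonnet-20241022-v2:0",
    virtual_key="<BEDROCK_VIRTUAL_KEY>",
    provider="bedrock",
    base_url="http://portkey-service-gateway.service.prod.example.com/v1",
)
agent = Agent(model=portkey_model, tools=[current_time])
response = agent("What time is it in UTC?")
```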

Built-in providers:
27 changes: 23 additions & 4 deletions pyproject.toml
@@ -35,7 +35,6 @@ dependencies = [
"watchdog>=6.0.0,<7.0.0",
"opentelemetry-api>=1.30.0,<2.0.0",
"opentelemetry-sdk>=1.30.0,<2.0.0",
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
]

[project.urls]
@@ -50,6 +49,8 @@ packages = ["src/strands"]
anthropic = [
"anthropic>=0.21.0,<1.0.0",
]
# Optional dependencies for different AI providers

dev = [
"commitizen>=4.4.0,<5.0.0",
"hatch>=1.0.0,<2.0.0",
@@ -78,13 +79,28 @@ ollama = [
openai = [
"openai>=1.68.0,<2.0.0",
]
otel = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
]
a2a = [
"a2a-sdk>=0.2.6",
"uvicorn>=0.34.2",
"httpx>=0.28.1",
"fastapi>=0.115.12",
"starlette>=0.46.2",
]

portkey = [
"portkey-ai>=1.0.0,<2.0.0",
]
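# The gateway client ships as an optional extra, so it can presumably be installed
# with `pip install "strands-agents[portkey]"` (distribution name assumed; not shown in this diff).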

[tool.hatch.version]
# Tells Hatch to use your version control system (git) to determine the version.
source = "vcs"

[tool.hatch.envs.hatch-static-analysis]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai","otel", "portkey"]

dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -107,7 +123,7 @@ lint-fix = [
]

[tool.hatch.envs.hatch-test]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "portkey"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
@@ -123,8 +139,11 @@ extra-args = [

[tool.hatch.envs.dev]
dev-mode = true
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"]
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel,", "portkey"]

[tool.hatch.envs.a2a]
dev-mode = true
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]


[[tool.hatch.envs.hatch-test.matrix]]
8 changes: 7 additions & 1 deletion src/strands/agent/__init__.py
@@ -8,12 +8,18 @@

from .agent import Agent
from .agent_result import AgentResult
from .conversation_manager import ConversationManager, NullConversationManager, SlidingWindowConversationManager
from .conversation_manager import (
ConversationManager,
NullConversationManager,
SlidingWindowConversationManager,
SummarizingConversationManager,
)

__all__ = [
"Agent",
"AgentResult",
"ConversationManager",
"NullConversationManager",
"SlidingWindowConversationManager",
"SummarizingConversationManager",
]
8 changes: 2 additions & 6 deletions src/strands/agent/agent.py
@@ -108,14 +108,10 @@ def find_normalized_tool_name() -> Optional[str]:
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name
for (tool_name, tool) in tool_registry.items()
if tool_name.replace("-", "_") == name
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
]

if len(filtered_tools) > 1:
raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}")

# The registry itself defends against similar names, so we can just take the first match
if filtered_tools:
return filtered_tools[0]

10 changes: 9 additions & 1 deletion src/strands/agent/conversation_manager/__init__.py
@@ -6,6 +6,8 @@
- NullConversationManager: A no-op implementation that does not modify conversation history
- SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context
size while preserving conversation coherence
- SummarizingConversationManager: An implementation that summarizes older context instead
of simply trimming it

Conversation managers help control memory usage and context length while maintaining relevant conversation state, which
is critical for effective agent interactions.
@@ -14,5 +16,11 @@
from .conversation_manager import ConversationManager
from .null_conversation_manager import NullConversationManager
from .sliding_window_conversation_manager import SlidingWindowConversationManager
from .summarizing_conversation_manager import SummarizingConversationManager

__all__ = ["ConversationManager", "NullConversationManager", "SlidingWindowConversationManager"]
__all__ = [
"ConversationManager",
"NullConversationManager",
"SlidingWindowConversationManager",
"SummarizingConversationManager",
]
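
For reference, a minimal sketch of opting into the newly exported manager, assuming its constructor can be called with defaults (its parameters are not visible in this diff):

```python
from strands import Agent
from strands.agent.conversation_manager import SummarizingConversationManager

# Swap in the summarizing manager; constructor defaults are assumed here.
agent = Agent(conversation_manager=SummarizingConversationManager())
```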
src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager):
invalid window states.
"""

def __init__(self, window_size: int = 40):
def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
"""Initialize the sliding window conversation manager.

Args:
window_size: Maximum number of messages to keep in the agent's history.
Defaults to 40 messages.
should_truncate_results: Whether to truncate tool results when a message is too large for the model's
context window. Defaults to True.
"""
self.window_size = window_size
self.should_truncate_results = should_truncate_results

def apply_management(self, agent: "Agent") -> None:
"""Apply the sliding window to the agent's messages array to maintain a manageable history size.
@@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
converted.
"""
messages = agent.messages

# Try to truncate the tool result first
last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages)
if last_message_idx_with_tool_results is not None and self.should_truncate_results:
logger.debug(
"message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results
)
results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results)
if results_truncated:
logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
return

# Fall back to trimming the message history when tool results cannot be truncated any further.
# If the number of messages is at most the window size, trim 2 messages; otherwise trim down to the window size.
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

@@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:

# Overwrite message history
messages[:] = messages[trim_index:]

def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
"""Truncate tool results in a message to reduce context size.

When a message contains tool results that are too large for the model's context window, this function
replaces the content of those tool results with a simple error message.

Args:
messages: The conversation message history.
msg_idx: Index of the message containing tool results to truncate.

Returns:
True if any changes were made to the message, False otherwise.
"""
if msg_idx >= len(messages) or msg_idx < 0:
return False

message = messages[msg_idx]
changes_made = False
tool_result_too_large_message = "The tool result was too large!"
for i, content in enumerate(message.get("content", [])):
if isinstance(content, dict) and "toolResult" in content:
tool_result_content_text = next(
(item["text"] for item in content["toolResult"]["content"] if "text" in item),
"",
)
# Skip if this tool result was already replaced with the error message; truncating again would not help
if (
message["content"][i]["toolResult"]["status"] == "error"
and tool_result_content_text == tool_result_too_large_message
):
logger.info("ToolResult has already been updated, skipping overwrite")
return False
# Update status to error with informative message
message["content"][i]["toolResult"]["status"] = "error"
message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
changes_made = True

return changes_made

def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]:
"""Find the index of the last message containing tool results.

This is useful for identifying messages that might need to be truncated to reduce context size.

Args:
messages: The conversation message history.

Returns:
Index of the last message with tool results, or None if no such message exists.
"""
# Iterate backwards through all messages (from newest to oldest)
for idx in range(len(messages) - 1, -1, -1):
# Check if this message has any content with toolResult
current_message = messages[idx]
has_tool_result = False

for content in current_message.get("content", []):
if isinstance(content, dict) and "toolResult" in content:
has_tool_result = True
break

if has_tool_result:
return idx

return None
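
A short sketch of the new toggle in use, assuming the constructor signature defined above: disabling truncation makes `reduce_context` fall straight through to trimming.

```python
from strands import Agent
from strands.agent.conversation_manager import SlidingWindowConversationManager

# Keep the default 40-message window but skip tool-result truncation,
# so reduce_context trims old messages directly instead.
manager = SlidingWindowConversationManager(window_size=40, should_truncate_results=False)
agent = Agent(conversation_manager=manager)
```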