Skip to content

Commit

Permalink
Add a test runner and 2 very simple tests for agents
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinb committed Sep 19, 2024
1 parent 543222a commit abb4393
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ ignore =
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
    # naming-convention hints aren't needed here
N802,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B950
optional-ascii-coding = True
Expand Down
2 changes: 0 additions & 2 deletions llama_stack/distribution/control_plane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .control_plane import * # noqa: F401 F403
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ async def _should_retrieve_context(
else:
return True

print(f"{enabled_tools=}")
return AgentTool.memory.value in enabled_tools

def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
Expand Down
4 changes: 1 addition & 3 deletions llama_stack/providers/impls/meta_reference/agents/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,10 @@ async def run_shields(
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)

res = await self.safety_api.run_shields(
results = await self.safety_api.run_shields(
messages=messages,
shields=shields,
)

results = res.responses
for shield, r in zip(shields, results):
if r.is_violation:
if shield.on_violation_action == OnViolationAction.RAISE:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncIterator, List, Optional, Union
from unittest.mock import MagicMock

import pytest

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403

from ..agent_instance import ChatAgent


class MockInferenceAPI:
    """Mock of the Inference API that always produces a canned "Mock response".

    Streaming calls yield the start -> progress -> complete chunk sequence;
    non-streaming calls yield a single complete ChatCompletionResponse.
    """

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = None,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncIterator[
        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
    ]:
        if not stream:
            # Non-streaming: one complete assistant message.
            yield ChatCompletionResponse(
                completion_message=CompletionMessage(
                    role="assistant", content="Mock response", stop_reason="end_of_turn"
                ),
                logprobs=[0.1, 0.2, 0.3] if logprobs else None,
            )
            return

        # Streaming: mimic the real event order — start, progress, complete.
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(event_type="start", delta="")
        )
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(event_type="progress", delta="Mock response")
        )
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type="complete", delta="", stop_reason="end_of_turn"
            )
        )


class MockSafetyAPI:
    """Mock of the Safety API: every shield run passes (never a violation)."""

    async def run_shields(
        self, messages: List[Message], shields: List[MagicMock]
    ) -> List[ShieldResponse]:
        # One non-violating response regardless of input.
        passing = ShieldResponse(shield_type="mock_shield", is_violation=False)
        return [passing]


class MockMemoryAPI:
    """In-memory mock of the Memory API backed by two plain dicts.

    `memory_banks` maps bank_id -> bank; `documents` maps
    bank_id -> {document_id -> document}.
    """

    def __init__(self):
        self.memory_banks = {}
        self.documents = {}

    def _require_bank(self, bank_id):
        # Every document operation validates the bank first.
        if bank_id not in self.documents:
            raise ValueError(f"Bank {bank_id} not found")

    async def create_memory_bank(self, name, config, url=None):
        # Ids are sequential: bank_0, bank_1, ...
        bank_id = f"bank_{len(self.memory_banks)}"
        new_bank = MemoryBank(bank_id, name, config, url)
        self.memory_banks[bank_id] = new_bank
        self.documents[bank_id] = {}
        return new_bank

    async def list_memory_banks(self):
        return list(self.memory_banks.values())

    async def get_memory_bank(self, bank_id):
        # Returns None for an unknown bank (no error).
        return self.memory_banks.get(bank_id)

    async def drop_memory_bank(self, bank_id):
        # Remove the bank and its documents if present; echo the id either way.
        if bank_id in self.memory_banks:
            del self.memory_banks[bank_id]
            del self.documents[bank_id]
        return bank_id

    async def insert_documents(self, bank_id, documents, ttl_seconds=None):
        self._require_bank(bank_id)
        for document in documents:
            self.documents[bank_id][document.document_id] = document

    async def update_documents(self, bank_id, documents):
        self._require_bank(bank_id)
        bank_docs = self.documents[bank_id]
        for document in documents:
            # Only overwrite documents that already exist; unknown ids are ignored.
            if document.document_id in bank_docs:
                bank_docs[document.document_id] = document

    async def query_documents(self, bank_id, query, params=None):
        self._require_bank(bank_id)
        # Simple mock implementation: ignore the query and return every document.
        chunks = [
            {
                "content": document.content,
                "token_count": 10,
                "document_id": document.document_id,
            }
            for document in self.documents[bank_id].values()
        ]
        return {"chunks": chunks, "scores": [1.0] * len(chunks)}

    async def get_documents(self, bank_id, document_ids):
        self._require_bank(bank_id)
        bank_docs = self.documents[bank_id]
        # Silently skip ids that are not present.
        return [bank_docs[doc_id] for doc_id in document_ids if doc_id in bank_docs]

    async def delete_documents(self, bank_id, document_ids):
        self._require_bank(bank_id)
        for doc_id in document_ids:
            self.documents[bank_id].pop(doc_id, None)


@pytest.fixture
def mock_inference_api():
    # Stub inference backend yielding canned "Mock response" chunks.
    return MockInferenceAPI()


@pytest.fixture
def mock_safety_api():
    # Stub safety backend whose shields never report a violation.
    return MockSafetyAPI()


@pytest.fixture
def mock_memory_api():
    # Dict-backed in-memory stand-in for the memory API.
    return MockMemoryAPI()


@pytest.fixture
def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
    """Build a ChatAgent wired to the mock inference, safety, and memory backends."""
    # NOTE(review): these fields mirror the ChatAgent/AgentConfig constructors —
    # keep them in sync if either signature changes.
    config = AgentConfig(
        instructions="You are a helpful assistant.",
        model="test_model",
        tool_choice=ToolChoice.auto,
        tools=[],
        sampling_params=SamplingParams(),
        output_shields=[],
        input_shields=[],
    )
    return ChatAgent(
        agent_config=config,
        safety_api=mock_safety_api,
        memory_api=mock_memory_api,
        inference_api=mock_inference_api,
        builtin_tools=[],
    )


@pytest.mark.asyncio
async def test_chat_agent_create_session(chat_agent):
    """A fresh session keeps its name, starts with no turns, and is registered."""
    name = "Test Session"
    session = chat_agent.create_session(name)
    assert session.session_name == name
    assert session.turns == []
    assert session.session_id in chat_agent.sessions


@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(chat_agent):
    """Executing a simple one-message turn emits the full event sequence."""
    session = chat_agent.create_session("Test Session")
    request = AgentTurnCreateRequest(
        agent_id="random",
        session_id=session.session_id,
        messages=[UserMessage(content="Hello")],
    )

    responses = [r async for r in chat_agent.create_and_execute_turn(request)]

    # Removed leftover debug print and the redundant `len(responses) > 0`
    # check (subsumed by the exact-count assertion below).
    assert len(responses) == 4  # TurnStart, StepStart, StepComplete, TurnComplete
    assert responses[0].event.payload.turn_id is not None


@pytest.mark.asyncio
async def test_run_shields_wrapper(chat_agent):
    """The shield wrapper yields StepStart then StepComplete with no violation."""
    responses = []
    async for chunk in chat_agent.run_shields_wrapper(
        turn_id="test_turn_id",
        messages=[UserMessage(content="Test message")],
        shields=[ShieldDefinition(shield_type="test_shield")],
        touchpoint="user-input",
    ):
        responses.append(chunk)

    assert len(responses) == 2  # StepStart, StepComplete
    start, complete = responses
    assert start.event.payload.step_type.value == "shield_call"
    assert not complete.event.payload.step_details.response.is_violation
5 changes: 5 additions & 0 deletions llama_stack/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
16 changes: 16 additions & 0 deletions llama_stack/scripts/run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Fail fast (also for errors while resolving THIS_DIR) and echo commands.
set -euo pipefail
set -x

# Absolute path of the directory containing this script, following symlinks.
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"

# Quote all expansions so paths with spaces don't word-split.
stack_dir="$(dirname "$THIS_DIR")"
# Assumes a sibling checkout layout: <root>/llama-models next to this repo.
models_dir="$(dirname "$(dirname "$stack_dir")")/llama-models"
PYTHONPATH="$models_dir:$stack_dir" pytest -p no:warnings --asyncio-mode auto --tb=short

0 comments on commit abb4393

Please sign in to comment.