Skip to content

Commit

Permalink
Add a test for CLI, but not fully done so disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinb committed Sep 19, 2024
1 parent 8b3ffa3 commit 132f942
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 19 deletions.
105 changes: 105 additions & 0 deletions llama_stack/cli/tests/test_stack_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from argparse import Namespace
from unittest.mock import MagicMock, patch

import pytest
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.cli.stack.build import StackBuild


# Skip the whole module at collection time: these CLI tests are not yet
# runnable ("not fully done so disabled" per the commit) — remove this
# once the mocks below are fully wired up.
pytest.skip(allow_module_level=True)


@pytest.fixture
def stack_build():
    """Return a StackBuild whose argparse subparsers are a MagicMock.

    Using a mock instead of a real ArgumentParser lets tests inspect how
    StackBuild configures its parser without touching sys.argv.
    """
    # Fix: the original also created an unused `parser = MagicMock()` local;
    # StackBuild only receives the subparsers object, so it is dropped.
    subparsers = MagicMock()
    return StackBuild(subparsers)


def test_stack_build_initialization(stack_build):
    """The command must register its run function as the parser default."""
    assert stack_build.parser is not None
    # Bug fix: the original used `assert mock.called_once_with(...)` —
    # `called_once_with` is not an assertion method on MagicMock; it merely
    # creates a truthy child mock, so that assert could never fail.
    # The mock's real assertion method is used instead.
    stack_build.parser.set_defaults.assert_called_once_with(
        func=stack_build._run_stack_build_command
    )


# Bug fix: the function signature takes `mock_build_config`, but the original
# only patched build_image — pytest would then try to resolve
# `mock_build_config` as a fixture and fail at collection time.  The missing
# BuildConfig patch is added, matching the sibling tests below.
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_config(
    mock_build_image, mock_build_config, stack_build
):
    """Building from a --config YAML file parses it and builds an image."""
    args = Namespace(
        config="test_config.yaml",
        template=None,
        list_templates=False,
        name=None,
        image_type="conda",
    )

    # Stub out the filesystem and YAML parsing so no real config is needed.
    with patch("builtins.open", MagicMock()):
        with patch("yaml.safe_load") as mock_yaml_load:
            mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
            mock_build_config.return_value = MagicMock()

            stack_build._run_stack_build_command(args)

            mock_build_config.assert_called_once()
            mock_build_image.assert_called_once()


@patch("llama_stack.cli.table.print_table")
def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
    """--list-templates renders the template table exactly once."""
    cli_args = Namespace(list_templates=True)

    stack_build._run_stack_build_command(cli_args)

    mock_print_table.assert_called_once()


@patch("prompt_toolkit.prompt")
@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_interactive(
    mock_build_image, mock_build_config, mock_prompt, stack_build
):
    """With no CLI flags at all, the command falls back to interactive prompts."""
    # Canned answers for the four interactive questions, in order:
    # build name, image type, distribution template, description.
    interactive_answers = [
        "test_name",
        "conda",
        "meta-reference",
        "test description",
    ]
    mock_prompt.side_effect = interactive_answers
    mock_build_config.return_value = MagicMock()

    cli_args = Namespace(
        config=None, template=None, list_templates=False, name=None, image_type=None
    )
    stack_build._run_stack_build_command(cli_args)

    # Every prompt answer must have been consumed.
    assert mock_prompt.call_count == len(interactive_answers)
    mock_build_config.assert_called_once()
    mock_build_image.assert_called_once()


@patch("llama_stack.distribution.datatypes.BuildConfig")
@patch("llama_stack.distribution.build.build_image")
def test_run_stack_build_command_with_template(
    mock_build_image, mock_build_config, stack_build
):
    """--template plus --name loads the template YAML and builds an image."""
    cli_args = Namespace(
        config=None,
        template="test_template",
        list_templates=False,
        name="test_name",
        image_type="docker",
    )

    # Stub out the filesystem and YAML parsing so no real template is needed.
    with patch("builtins.open", MagicMock()), patch("yaml.safe_load") as mock_yaml_load:
        mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
        mock_build_config.return_value = MagicMock()

        stack_build._run_stack_build_command(cli_args)

        mock_build_config.assert_called_once()
        mock_build_image.assert_called_once()
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ async def chat_completion(
delta="AI is a fascinating field...",
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="progress",
delta=ToolCallDelta(
content=ToolCall(
call_id="123",
tool_name=BuiltinTool.brave_search.value,
arguments={"query": "AI history"},
),
parse_status="success",
),
)
)
# yield ChatCompletionResponseStreamChunk(
# event=ChatCompletionResponseEvent(
# event_type="progress",
# delta=ToolCallDelta(
# content=ToolCall(
# call_id="123",
# tool_name=BuiltinTool.brave_search.value,
# arguments={"query": "AI history"},
# ),
# parse_status="success",
# ),
# )
# )
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="complete",
Expand Down Expand Up @@ -179,10 +179,10 @@ async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
instructions="You are a helpful assistant.",
sampling_params=SamplingParams(),
tools=[
SearchToolDefinition(
name="brave_search",
api_key="test_key",
),
# SearchToolDefinition(
# name="brave_search",
# api_key="test_key",
# ),
],
tool_choice=ToolChoice.auto,
input_shields=[],
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
set -euo pipefail
set -x

stack_dir=$(dirname $THIS_DIR)
models_dir=$(dirname $(dirname $stack_dir))/llama-models
stack_dir=$(dirname $(dirname $THIS_DIR))
models_dir=$(dirname $stack_dir)/llama-models
PYTHONPATH=$models_dir:$stack_dir pytest -p no:warnings --asyncio-mode auto --tb=short

0 comments on commit 132f942

Please sign in to comment.