Skip to content

vLLM Model Provider implementation #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Prev Previous commit
Next Next commit
Update test_vllm.py
Updated test cases
  • Loading branch information
AhilanPonnusamy authored May 22, 2025
commit 1e4d14c40cc2f95204dc03ee8e4457b9014c84d2
102 changes: 43 additions & 59 deletions tests/strands/models/test_vllm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
import requests
import json

from types import SimpleNamespace
from strands.models.vllm import VLLMModel


Expand Down Expand Up @@ -54,45 +57,28 @@ def test_format_request_with_system_prompt(model, messages, system_prompt):


def test_format_chunk_text():
    """format_chunk wraps a plain-text content delta in a contentBlockDelta event."""
    # New-style internal chunk: chunk_type/data_type/data keys (not raw OpenAI deltas).
    chunk = {"chunk_type": "content_delta", "data_type": "text", "data": "World"}
    # format_chunk does not read instance state, so a bare None works as self.
    formatted = VLLMModel.format_chunk(None, chunk)
    assert formatted == {"contentBlockDelta": {"delta": {"text": "World"}}}


def test_format_chunk_tool_call_delta():
    """format_chunk converts a tool-call delta into a toolUse contentBlockDelta.

    The tool data is an object with .name/.arguments attributes (SimpleNamespace
    stands in for the model provider's tool-call object); the formatted event
    carries the arguments JSON-encoded under ["toolUse"]["input"].
    """
    chunk = {
        "chunk_type": "content_delta",
        "data_type": "tool",
        # SimpleNamespace mimics an object exposing .name and .arguments.
        "data": SimpleNamespace(name="get_time", arguments={"timezone": "UTC"}),
    }
    formatted = VLLMModel.format_chunk(None, chunk)
    assert "contentBlockDelta" in formatted
    assert "toolUse" in formatted["contentBlockDelta"]["delta"]
    # Arguments are serialized to a JSON string; round-trip to verify content.
    assert json.loads(formatted["contentBlockDelta"]["delta"]["toolUse"]["input"])["timezone"] == "UTC"


def test_stream_response(monkeypatch, model, messages):
mock_lines = [
'data: {"choices":[{"delta":{"content":"Hello"}}]}\n',
'data: {"choices":[{"finish_reason":"stop"}]}\n',
'data: {"choices":[{"delta":{"content":" world"}}]}\n',
"data: [DONE]\n",
]

Expand All @@ -103,65 +89,63 @@ def __init__(self):
def __enter__(self):
return self

def __exit__(self, *a):
pass
def __exit__(self, *a): pass

def iter_lines(self, decode_unicode=False):
return iter(mock_lines)

monkeypatch.setattr(requests, "post", lambda *a, **kw: MockResponse())

request = model.format_request(messages)
chunks = list(model.stream(request))
chunks = list(model.stream(model.format_request(messages)))
chunk_types = [c.get("chunk_type") for c in chunks]

assert {"chunk_type": "message_start"} in chunks
assert any(chunk.get("chunk_type") == "content_delta" for chunk in chunks)
assert {"chunk_type": "content_stop", "data_type": "text"} in chunks
assert {"chunk_type": "message_stop", "data": "end_turn"} in chunks
assert "message_start" in chunk_types
assert chunk_types.count("content_delta") == 2
assert "content_stop" in chunk_types
assert "message_stop" in chunk_types


def test_stream_tool_call(monkeypatch, model, messages):
    """stream() detects an inline <tool_call>...</tool_call> payload in the SSE text.

    The mocked server emits ordinary text first, then a content delta whose text
    embeds a JSON tool call; the stream must surface tool-typed chunks for it.
    """
    tool_call = {
        "name": "current_time",
        "arguments": {"timezone": "UTC"},
    }
    tool_call_json = json.dumps(tool_call)
    # SSE line whose delta content wraps the tool call in <tool_call> tags.
    data_str = json.dumps({
        "choices": [
            {"delta": {"content": f"<tool_call>{tool_call_json}</tool_call>"}}
        ]
    })
    mock_lines = [
        'data: {"choices":[{"delta":{"content":"Some answer before tool."}}]}\n',
        f"data: {data_str}\n",
        "data: [DONE]\n",
    ]

    class MockResponse:
        """Minimal stand-in for requests.Response used as a context manager."""

        def __init__(self):
            self.status_code = 200

        def __enter__(self):
            return self

        def __exit__(self, *a):
            pass

        def iter_lines(self, decode_unicode=False):
            return iter(mock_lines)

    monkeypatch.setattr(requests, "post", lambda *a, **kw: MockResponse())

    chunks = list(model.stream(model.format_request(messages)))
    # A tool call must open a tool-typed content block...
    tool_chunks = [c for c in chunks if c.get("chunk_type") == "content_start" and c.get("data_type") == "tool"]
    assert tool_chunks
    # ...and at least one chunk must be tool-flavored overall.
    assert any("tool_use" in c.get("chunk_type", "") or "tool" in c.get("data_type", "") for c in chunks)


def test_stream_server_error(monkeypatch, model, messages):
class ErrorResponse:
def __init__(self):
self.status_code = 500
self.text = "Internal Error"

def __enter__(self):
return self

def __exit__(self, *a):
pass

def iter_lines(self, decode_unicode=False):
return iter([])
def __enter__(self): return self
def __exit__(self, *a): pass
def iter_lines(self, decode_unicode=False): return iter([])

monkeypatch.setattr(requests, "post", lambda *a, **kw: ErrorResponse())

Expand Down