Commit 3c02657

Add test for LoRA inference in VLLMInferenceEngine
1 parent 916b990 commit 3c02657

File tree

1 file changed: +61 -2 lines changed


tests/inference/test_vllm_inference_engine.py (+61 -2)
@@ -1,7 +1,7 @@
 import tempfile
 from pathlib import Path
 from typing import List
-from unittest.mock import Mock, patch
+from unittest.mock import ANY, Mock, patch
 
 import jsonlines
 import pytest
@@ -12,6 +12,7 @@
 
 try:
     vllm_import_failed = False
+    from vllm.lora.request import LoRARequest  # type: ignore
     from vllm.outputs import (  # pyright: ignore[reportMissingImports]
         CompletionOutput,
         RequestOutput,
@@ -50,9 +51,10 @@ def mock_vllm():
         yield mvllm
 
 
-def _get_default_model_params() -> ModelParams:
+def _get_default_model_params(use_lora: bool = False) -> ModelParams:
     return ModelParams(
         model_name="openai-community/gpt2",
+        adapter_model="/path/to/adapter" if use_lora else None,
         trust_remote_code=True,
     )
 
@@ -113,6 +115,63 @@ def test_infer_online(mock_vllm):
     mock_vllm_instance.chat.assert_called_once()
 
 
+@pytest.mark.skipif(vllm_import_failed, reason="vLLM not available")
+def test_infer_online_lora(mock_vllm):
+    mock_vllm_instance = Mock()
+    mock_vllm.LLM.return_value = mock_vllm_instance
+    mock_vllm_instance.chat.return_value = [
+        _create_vllm_output(["The first time I saw"], "123")
+    ]
+
+    lora_request = LoRARequest(
+        lora_name="oumi_lora_adapter",
+        lora_int_id=1,
+        lora_path="/path/to/adapter",
+    )
+    mock_vllm.lora.request.LoRARequest.return_value = lora_request
+    engine = VLLMInferenceEngine(_get_default_model_params(use_lora=True))
+    conversation = Conversation(
+        messages=[
+            Message(
+                content="Hello world!",
+                role=Role.USER,
+            ),
+            Message(
+                content="Hello again!",
+                role=Role.USER,
+            ),
+        ],
+        metadata={"foo": "bar"},
+        conversation_id="123",
+    )
+    expected_result = [
+        Conversation(
+            messages=[
+                *conversation.messages,
+                Message(
+                    content="The first time I saw",
+                    role=Role.ASSISTANT,
+                ),
+            ],
+            metadata={"foo": "bar"},
+            conversation_id="123",
+        )
+    ]
+    result = engine.infer_online([conversation], GenerationConfig(max_new_tokens=5))
+    assert expected_result == result
+
+    mock_vllm.lora.request.LoRARequest.assert_called_once_with(
+        lora_name="oumi_lora_adapter",
+        lora_int_id=1,
+        lora_path="/path/to/adapter",
+    )
+    mock_vllm_instance.chat.assert_called_once_with(
+        ANY,
+        sampling_params=ANY,
+        lora_request=lora_request,
+    )
+
+
 @pytest.mark.skipif(vllm_import_failed, reason="vLLM not available")
 def test_infer_online_empty(mock_vllm):
     mock_vllm_instance = Mock()
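
For context, here is a minimal sketch of the engine-side wiring this test pins down, assuming vLLM is installed: when ModelParams.adapter_model is set, the engine is expected to build a LoRARequest and pass it through to LLM.chat. The helper name build_lora_request is hypothetical, not the actual Oumi implementation; LoRARequest, enable_lora, and the lora_request keyword of LLM.chat are real vLLM API, though exact signatures can vary across vLLM versions.

# Illustrative sketch only -- not the actual VLLMInferenceEngine code.
from typing import Optional

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def build_lora_request(adapter_path: Optional[str]) -> Optional[LoRARequest]:
    """Build a LoRARequest for an adapter path, or None when no adapter is set."""
    if adapter_path is None:
        return None
    # vLLM identifies an adapter by a (name, integer id, path) triple; the id
    # must be a unique positive integer. The name mirrors the one the test
    # above asserts on.
    return LoRARequest(
        lora_name="oumi_lora_adapter",
        lora_int_id=1,
        lora_path=adapter_path,
    )


# Usage: LoRA must be enabled when the LLM is constructed; the request is
# then supplied per chat call.
llm = LLM(model="openai-community/gpt2", enable_lora=True)
outputs = llm.chat(
    [{"role": "user", "content": "Hello world!"}],
    sampling_params=SamplingParams(max_tokens=5),
    lora_request=build_lora_request("/path/to/adapter"),
)

Mocking vllm.lora.request.LoRARequest in the test lets it assert on the exact constructor arguments without loading adapter weights from disk.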
