@@ -1,7 +1,7 @@
 import tempfile
 from pathlib import Path
 from typing import List
-from unittest.mock import Mock, patch
+from unittest.mock import ANY, Mock, patch
 
 import jsonlines
 import pytest
@@ -12,6 +12,7 @@
 
 try:
     vllm_import_failed = False
+    from vllm.lora.request import LoRARequest  # type: ignore
     from vllm.outputs import (  # pyright: ignore[reportMissingImports]
         CompletionOutput,
         RequestOutput,
@@ -50,9 +51,10 @@ def mock_vllm():
         yield mvllm
 
 
-def _get_default_model_params() -> ModelParams:
+def _get_default_model_params(use_lora: bool = False) -> ModelParams:
     return ModelParams(
         model_name="openai-community/gpt2",
+        adapter_model="/path/to/adapter" if use_lora else None,
         trust_remote_code=True,
     )
 
@@ -113,6 +115,63 @@ def test_infer_online(mock_vllm):
     mock_vllm_instance.chat.assert_called_once()
 
 
+@pytest.mark.skipif(vllm_import_failed, reason="vLLM not available")
+def test_infer_online_lora(mock_vllm):
+    mock_vllm_instance = Mock()
+    mock_vllm.LLM.return_value = mock_vllm_instance
+    mock_vllm_instance.chat.return_value = [
+        _create_vllm_output(["The first time I saw"], "123")
+    ]
+
+    lora_request = LoRARequest(
+        lora_name="oumi_lora_adapter",
+        lora_int_id=1,
+        lora_path="/path/to/adapter",
+    )
+    mock_vllm.lora.request.LoRARequest.return_value = lora_request
+    engine = VLLMInferenceEngine(_get_default_model_params(use_lora=True))
+    conversation = Conversation(
+        messages=[
+            Message(
+                content="Hello world!",
+                role=Role.USER,
+            ),
+            Message(
+                content="Hello again!",
+                role=Role.USER,
+            ),
+        ],
+        metadata={"foo": "bar"},
+        conversation_id="123",
+    )
+    expected_result = [
+        Conversation(
+            messages=[
+                *conversation.messages,
+                Message(
+                    content="The first time I saw",
+                    role=Role.ASSISTANT,
+                ),
+            ],
+            metadata={"foo": "bar"},
+            conversation_id="123",
+        )
+    ]
+    result = engine.infer_online([conversation], GenerationConfig(max_new_tokens=5))
+    assert expected_result == result
+
+    mock_vllm.lora.request.LoRARequest.assert_called_once_with(
+        lora_name="oumi_lora_adapter",
+        lora_int_id=1,
+        lora_path="/path/to/adapter",
+    )
+    mock_vllm_instance.chat.assert_called_once_with(
+        ANY,
+        sampling_params=ANY,
+        lora_request=lora_request,
+    )
+
+
 @pytest.mark.skipif(vllm_import_failed, reason="vLLM not available")
 def test_infer_online_empty(mock_vllm):
     mock_vllm_instance = Mock()
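
For context, the new test pins down how the engine is expected to wire LoRA into vLLM when `adapter_model` is set: build a `LoRARequest` named `oumi_lora_adapter` with `lora_int_id=1` and the adapter path, then pass it to `LLM.chat()` as `lora_request`. The snippet below is a minimal sketch of that expected flow inferred from the assertions, not the actual `VLLMInferenceEngine` code; the helper name `build_llm_with_lora` and the `enable_lora` constructor flag are assumptions for illustration.

# Hypothetical sketch (not the real VLLMInferenceEngine) of the LoRA wiring
# that test_infer_online_lora asserts.
from typing import Optional

import vllm
from vllm.lora.request import LoRARequest


def build_llm_with_lora(
    model_name: str,
    adapter_model: Optional[str],
    trust_remote_code: bool = False,
) -> tuple[vllm.LLM, Optional[LoRARequest]]:
    lora_request = None
    if adapter_model is not None:
        # Mirrors the expected constructor call checked by the test.
        lora_request = LoRARequest(
            lora_name="oumi_lora_adapter",
            lora_int_id=1,
            lora_path=adapter_model,
        )
    llm = vllm.LLM(
        model=model_name,
        trust_remote_code=trust_remote_code,
        enable_lora=lora_request is not None,  # assumed flag for LoRA support
    )
    return llm, lora_request


# The engine would then forward the request on every generation call, which is
# what mock_vllm_instance.chat.assert_called_once_with(..., lora_request=...)
# verifies:
#     llm.chat(messages, sampling_params=params, lora_request=lora_request)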