 import asyncio
+import time
+from dataclasses import dataclass
 from typing import (
     Optional,
     List,
     Dict,
     Any,
-    AsyncIterator,
     Union,
 )

-from fastapi import HTTPException
 from loguru import logger
 from openai.types.chat import ChatCompletionMessageParam
-from transformers import PreTrainedTokenizer
+from openai.types.completion_choice import Logprobs
+from openai.types.model import Model
+from pydantic import BaseModel
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer

 from api.adapter import get_prompt_adapter
 from api.generation import build_qwen_chat_input


+@dataclass
+class LoRA:
+    name: str
+    local_path: str
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[Model] = []
+
+
 class VllmEngine:
     def __init__(
         self,
         model: AsyncLLMEngine,
-        tokenizer: PreTrainedTokenizer,
         model_name: str,
         prompt_name: Optional[str] = None,
-        context_len: Optional[int] = -1,
+        lora_modules: Optional[List[LoRA]] = None,
     ):
         """
         Initializes the VLLMEngine object.

         Args:
             model: The AsyncLLMEngine object.
-            tokenizer: The PreTrainedTokenizer object.
             model_name: The name of the model.
             prompt_name: The name of the prompt (optional).
-            context_len: The length of the context (optional, default=-1).
+            lora_modules: LoRA adapters to register with the engine (optional).
         """
         self.model = model
         self.model_name = model_name.lower()
-        self.tokenizer = tokenizer
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

-        model_config = asyncio.run(self.model.get_model_config())
-        if "qwen" in self.model_name:
-            self.max_model_len = context_len if context_len > 0 else 8192
+        if lora_modules is None:
+            self.lora_requests = []
         else:
-            self.max_model_len = model_config.max_model_len
+            try:
+                from vllm.lora.request import LoRARequest
+                self.lora_requests = [
+                    LoRARequest(
+                        lora_name=lora.name,
+                        lora_int_id=i,
+                        lora_local_path=lora.local_path,
+                    ) for i, lora in enumerate(lora_modules, start=1)
+                ]
+            except ImportError:
+                self.lora_requests = []
+
+        try:
+            event_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            event_loop = None
+
+        if event_loop is not None and event_loop.is_running():
+            # If the engine is instantiated by Ray Serve, there is already a
+            # running event loop, so schedule the async post-init on it.
+            event_loop.create_task(self._post_init())
+        else:
+            # When using a single vLLM instance without engine_use_ray,
+            # no loop is running yet, so run the coroutine to completion.
+            asyncio.run(self._post_init())
+
+    async def _post_init(self):
+        engine_model_config = await self.model.get_model_config()
+        self.max_model_len = engine_model_config.max_model_len
+
+        # A separate tokenizer to map token IDs to strings.
+        self.tokenizer = get_tokenizer(
+            engine_model_config.tokenizer,
+            tokenizer_mode=engine_model_config.tokenizer_mode,
+            trust_remote_code=engine_model_config.trust_remote_code,
+        )
+
+    async def show_available_models(self) -> ModelList:
+        """Show the available models: the base model plus any registered LoRA adapters."""
+        model_cards = [
+            Model(
+                id=self.model_name,
+                object="model",
+                created=int(time.time()),
+                owned_by="vllm"
+            )
+        ]
+        lora_cards = [
+            Model(
+                id=lora.lora_name,
+                object="model",
+                created=int(time.time()),
+                owned_by="vllm"
+            )
+            for lora in self.lora_requests
+        ]
+        model_cards.extend(lora_cards)
+        return ModelList(data=model_cards)
+
+    def create_logprobs(
+        self,
+        token_ids: List[int],
+        top_logprobs: Optional[List[Optional[Any]]] = None,
+        num_output_top_logprobs: Optional[int] = None,
+        initial_text_offset: int = 0,
+    ) -> Logprobs:
+        """Create OpenAI-style logprobs."""
+        logprobs = Logprobs()
+        # The openai Logprobs model defaults its fields to None, so the
+        # accumulator lists must be created before appending to them.
+        logprobs.tokens = []
+        logprobs.token_logprobs = []
+        logprobs.text_offset = []
+        last_token_len = 0
+        if num_output_top_logprobs:
+            logprobs.top_logprobs = []
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is not None:
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+            else:
+                # No logprob info for this step: fall back to decoding the token ID.
+                token_logprob = None
+                token = self.tokenizer.decode(token_id)
+
+            logprobs.tokens.append(token)
+            logprobs.token_logprobs.append(token_logprob)
+
+            if len(logprobs.text_offset) == 0:
+                logprobs.text_offset.append(initial_text_offset)
+            else:
+                logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
+            last_token_len = len(token)
+
+            if num_output_top_logprobs:
+                logprobs.top_logprobs.append(
+                    {
+                        p.decoded_token: p.logprob
+                        for p in step_top_logprobs.values()
+                    }
+                    if step_top_logprobs else None
+                )
+        return logprobs
+
+    def _maybe_get_lora(self, model_name):
+        for lora in self.lora_requests:
+            if model_name == lora.lora_name:
+                logger.info(f"Lora request: {model_name}")
+                return lora
+        return None

     def apply_chat_template(
         self,
@@ -104,61 +215,6 @@ def convert_to_inputs(
         max_input_tokens = max(self.max_model_len - max_tokens, input_len)
         return input_ids[-max_input_tokens:]

-    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
-        """
-        Generates text based on the given parameters and request ID.
-
-        Args:
-            params (Dict[str, Any]): A dictionary of parameters for text generation.
-            request_id (str): The ID of the request.
-
-        Yields:
-            Any: The generated text.
-        """
-        max_tokens = params.get("max_tokens", 256)
-        prompt_or_messages = params.get("prompt_or_messages")
-        if isinstance(prompt_or_messages, list):
-            prompt_or_messages = self.apply_chat_template(
-                prompt_or_messages,
-                functions=params.get("functions"),
-                tools=params.get("tools"),
-            )
-
-        if isinstance(prompt_or_messages, list):
-            prompt, token_ids = None, prompt_or_messages
-        else:
-            prompt, token_ids = prompt_or_messages, None
-
-        token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)
-        try:
-            sampling_params = SamplingParams(
-                n=params.get("n", 1),
-                presence_penalty=params.get("presence_penalty", 0.),
-                frequency_penalty=params.get("frequency_penalty", 0.),
-                temperature=params.get("temperature", 0.9),
-                top_p=params.get("top_p", 0.8),
-                stop=params.get("stop", []),
-                stop_token_ids=params.get("stop_token_ids", []),
-                max_tokens=params.get("max_tokens", 256),
-                repetition_penalty=params.get("repetition_penalty", 1.03),
-                min_p=params.get("min_p", 0.0),
-                best_of=params.get("best_of", 1),
-                ignore_eos=params.get("ignore_eos", False),
-                use_beam_search=params.get("use_beam_search", False),
-                skip_special_tokens=params.get("skip_special_tokens", True),
-                spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
-            )
-            result_generator = self.model.generate(
-                prompt_or_messages if isinstance(prompt_or_messages, str) else None,
-                sampling_params,
-                request_id,
-                token_ids,
-            )
-        except ValueError as e:
-            raise HTTPException(status_code=400, detail=str(e)) from e
-
-        return result_generator
-
     @property
     def stop(self):
         """
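A minimal usage sketch for the revised constructor, assuming a vLLM build that provides AsyncEngineArgs, AsyncLLMEngine.from_engine_args, and LoRA support; the model path, adapter name, and adapter path are placeholders rather than values from this commit:

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    # Placeholder model and adapter paths; enable_lora matters only if LoRA requests are used.
    engine_args = AsyncEngineArgs(model="/path/to/base-model", enable_lora=True)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    vllm_engine = VllmEngine(
        model=engine,
        model_name="qwen",
        lora_modules=[LoRA(name="my-adapter", local_path="/path/to/adapter")],
    )

    # Lists the base model plus the registered LoRA adapter.
    models = asyncio.run(vllm_engine.show_available_models())
    print([m.id for m in models.data])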