
Commit 82b0592

Author: xusenlin
Message: Update vllm version
1 parent 00cc010 commit 82b0592

File tree

9 files changed: +512 additions, -230 deletions

api/config.py

Lines changed: 31 additions & 0 deletions
@@ -176,6 +176,37 @@ class Settings(BaseModel):
         default=get_env("QUANTIZATION_METHOD", None),
         description="Quantization method for vllm server."
     )
+    enforce_eager: Optional[bool] = Field(
+        default=get_bool_env("ENFORCE_EAGER"),
+        description="Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal performance and flexibility."
+    )
+    max_context_len_to_capture: Optional[int] = Field(
+        default=int(get_env("MAX_CONTEXT_LEN_TO_CAPTURE", 8192)),
+        description="Maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode."
+    )
+    max_loras: Optional[int] = Field(
+        default=int(get_env("MAX_LORAS", 1)),
+        description="Max number of LoRAs in a single batch."
+    )
+    max_lora_rank: Optional[int] = Field(
+        default=int(get_env("MAX_LORA_RANK", 32)),
+        description="Max LoRA rank."
+    )
+    lora_extra_vocab_size: Optional[int] = Field(
+        default=int(get_env("LORA_EXTRA_VOCAB_SIZE", 256)),
+        description="Maximum size of extra vocabulary that can be present in a LoRA adapter added to the base model vocabulary."
+    )
+    lora_dtype: Optional[str] = Field(
+        default=get_env("LORA_DTYPE", "auto"),
+        description="Data type for LoRA. If auto, will default to base model dtype."
+    )
+    max_cpu_loras: Optional[int] = Field(
+        default=int(get_env("MAX_CPU_LORAS", -1)),
+        ge=-1,
+    )
+    lora_modules: Optional[str] = Field(
+        default=get_env("LORA_MODULES", ""),
+    )

     # support for transformers.TextIteratorStreamer
     use_streamer_v2: Optional[bool] = Field(

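The new settings above are read from environment variables through get_env/get_bool_env when api.config is imported. Below is a minimal sketch, not part of the commit, of how they could be exported before that import happens; the adapter names and paths are hypothetical.

    import os

    # Export the new vLLM/LoRA settings before api.config is imported,
    # assuming get_env/get_bool_env read from os.environ and the Settings
    # field defaults are resolved at import time.
    os.environ["ENFORCE_EAGER"] = "false"              # keep CUDA graphs enabled
    os.environ["MAX_CONTEXT_LEN_TO_CAPTURE"] = "8192"
    os.environ["MAX_LORAS"] = "2"                      # adapters per batch
    os.environ["MAX_LORA_RANK"] = "64"
    os.environ["LORA_DTYPE"] = "auto"
    # "+"-separated name=path pairs, parsed later in create_vllm_engine()
    os.environ["LORA_MODULES"] = "sql-lora=/data/loras/sql+chat-lora=/data/loras/chat"

    from api.config import SETTINGS  # import only after the environment is set

    print(SETTINGS.max_loras, SETTINGS.lora_modules)
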
api/core/vllm_engine.py

Lines changed: 124 additions & 68 deletions
@@ -1,54 +1,165 @@
 import asyncio
+import time
+from dataclasses import dataclass
 from typing import (
     Optional,
     List,
     Dict,
     Any,
-    AsyncIterator,
     Union,
 )

-from fastapi import HTTPException
 from loguru import logger
 from openai.types.chat import ChatCompletionMessageParam
-from transformers import PreTrainedTokenizer
+from openai.types.completion_choice import Logprobs
+from openai.types.model import Model
+from pydantic import BaseModel
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer

 from api.adapter import get_prompt_adapter
 from api.generation import build_qwen_chat_input


+@dataclass
+class LoRA:
+    name: str
+    local_path: str
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[Model] = []
+
+
 class VllmEngine:
     def __init__(
         self,
         model: AsyncLLMEngine,
-        tokenizer: PreTrainedTokenizer,
         model_name: str,
         prompt_name: Optional[str] = None,
-        context_len: Optional[int] = -1,
+        lora_modules: Optional[List[LoRA]] = None,
     ):
         """
         Initializes the VLLMEngine object.

         Args:
             model: The AsyncLLMEngine object.
-            tokenizer: The PreTrainedTokenizer object.
             model_name: The name of the model.
             prompt_name: The name of the prompt (optional).
-            context_len: The length of the context (optional, default=-1).
         """
         self.model = model
         self.model_name = model_name.lower()
-        self.tokenizer = tokenizer
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

-        model_config = asyncio.run(self.model.get_model_config())
-        if "qwen" in self.model_name:
-            self.max_model_len = context_len if context_len > 0 else 8192
+        if lora_modules is None:
+            self.lora_requests = []
         else:
-            self.max_model_len = model_config.max_model_len
+            try:
+                from vllm.lora.request import LoRARequest
+                self.lora_requests = [
+                    LoRARequest(
+                        lora_name=lora.name,
+                        lora_int_id=i,
+                        lora_local_path=lora.local_path,
+                    ) for i, lora in enumerate(lora_modules, start=1)
+                ]
+            except ImportError:
+                self.lora_requests = []
+
+        try:
+            event_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            event_loop = None
+
+        if event_loop is not None and event_loop.is_running():
+            # If the current is instanced by Ray Serve,
+            # there is already a running event loop
+            event_loop.create_task(self._post_init())
+        else:
+            # When using single vLLM without engine_use_ray
+            asyncio.run(self._post_init())
+
+    async def _post_init(self):
+        engine_model_config = await self.model.get_model_config()
+        self.max_model_len = engine_model_config.max_model_len
+
+        # A separate tokenizer to map token IDs to strings.
+        self.tokenizer = get_tokenizer(
+            engine_model_config.tokenizer,
+            tokenizer_mode=engine_model_config.tokenizer_mode,
+            trust_remote_code=engine_model_config.trust_remote_code,
+        )
+
+    async def show_available_models(self) -> ModelList:
+        """Show available models. Right now we only have one model."""
+        model_cards = [
+            Model(
+                id=self.model_name,
+                object="model",
+                created=int(time.time()),
+                owned_by="vllm"
+            )
+        ]
+        lora_cards = [
+            Model(
+                id=lora.lora_name,
+                object="model",
+                created=int(time.time()),
+                owned_by="vllm"
+            )
+            for lora in self.lora_requests
+        ]
+        model_cards.extend(lora_cards)
+        return ModelList(data=model_cards)
+
+    def create_logprobs(
+        self,
+        token_ids: List[int],
+        top_logprobs: Optional[List[Optional[Any]]] = None,
+        num_output_top_logprobs: Optional[int] = None,
+        initial_text_offset: int = 0,
+    ):
+        """Create OpenAI-style logprobs."""
+        logprobs = Logprobs()
+        last_token_len = 0
+        if num_output_top_logprobs:
+            logprobs.top_logprobs = []
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is not None:
+                token_logprob = step_top_logprobs[token_id].logprob
+            else:
+                token_logprob = None
+
+            token = step_top_logprobs[token_id].decoded_token
+            logprobs.tokens.append(token)
+            logprobs.token_logprobs.append(token_logprob)
+
+            if len(logprobs.text_offset) == 0:
+                logprobs.text_offset.append(initial_text_offset)
+            else:
+                logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
+            last_token_len = len(token)
+
+            if num_output_top_logprobs:
+                logprobs.top_logprobs.append(
+                    {
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    }
+                    if step_top_logprobs else None
+                )
+        return logprobs
+
+    def _maybe_get_lora(self, model_name):
+        for lora in self.lora_requests:
+            if model_name == lora.lora_name:
+                logger.info(f"Lora request: {model_name}")
+                return lora
+        return None

     def apply_chat_template(
         self,
@@ -104,61 +215,6 @@ def convert_to_inputs(
         max_input_tokens = max(self.max_model_len - max_tokens, input_len)
         return input_ids[-max_input_tokens:]

-    def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
-        """
-        Generates text based on the given parameters and request ID.
-
-        Args:
-            params (Dict[str, Any]): A dictionary of parameters for text generation.
-            request_id (str): The ID of the request.
-
-        Yields:
-            Any: The generated text.
-        """
-        max_tokens = params.get("max_tokens", 256)
-        prompt_or_messages = params.get("prompt_or_messages")
-        if isinstance(prompt_or_messages, list):
-            prompt_or_messages = self.apply_chat_template(
-                prompt_or_messages,
-                functions=params.get("functions"),
-                tools=params.get("tools"),
-            )
-
-        if isinstance(prompt_or_messages, list):
-            prompt, token_ids = None, prompt_or_messages
-        else:
-            prompt, token_ids = prompt_or_messages, None
-
-        token_ids = self.convert_to_inputs(prompt, token_ids, max_tokens=max_tokens)
-        try:
-            sampling_params = SamplingParams(
-                n=params.get("n", 1),
-                presence_penalty=params.get("presence_penalty", 0.),
-                frequency_penalty=params.get("frequency_penalty", 0.),
-                temperature=params.get("temperature", 0.9),
-                top_p=params.get("top_p", 0.8),
-                stop=params.get("stop", []),
-                stop_token_ids=params.get("stop_token_ids", []),
-                max_tokens=params.get("max_tokens", 256),
-                repetition_penalty=params.get("repetition_penalty", 1.03),
-                min_p=params.get("min_p", 0.0),
-                best_of=params.get("best_of", 1),
-                ignore_eos=params.get("ignore_eos", False),
-                use_beam_search=params.get("use_beam_search", False),
-                skip_special_tokens=params.get("skip_special_tokens", True),
-                spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
-            )
-            result_generator = self.model.generate(
-                prompt_or_messages if isinstance(prompt_or_messages, str) else None,
-                sampling_params,
-                request_id,
-                token_ids,
-            )
-        except ValueError as e:
-            raise HTTPException(status_code=400, detail=str(e)) from e
-
-        return result_generator
-
     @property
     def stop(self):
         """

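For reference, a small sketch (not part of the commit) of how the refactored engine could be constructed directly: the tokenizer and max_model_len are now resolved asynchronously in _post_init, and each LoRA entry becomes a vllm LoRARequest that _maybe_get_lora can later match against a requested model name. The engine variable, base model name, and adapter paths below are assumptions.

    from api.core.vllm_engine import LoRA, VllmEngine

    # `engine` is assumed to be an AsyncLLMEngine built elsewhere
    # (e.g. by api.models.create_vllm_engine).
    lora_modules = [
        LoRA(name="sql-lora", local_path="/data/loras/sql"),    # hypothetical adapter
        LoRA(name="chat-lora", local_path="/data/loras/chat"),  # hypothetical adapter
    ]

    vllm_engine = VllmEngine(
        engine,
        "qwen-14b-chat",            # base model name (placeholder)
        prompt_name=None,
        lora_modules=lora_modules,
    )

    # Requests addressed to "sql-lora" resolve to its LoRARequest;
    # unknown names (including the base model) return None.
    lora_request = vllm_engine._maybe_get_lora("sql-lora")
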
api/models.py

Lines changed: 14 additions & 11 deletions
@@ -76,8 +76,7 @@ def create_vllm_engine():
     try:
         from vllm.engine.arg_utils import AsyncEngineArgs
         from vllm.engine.async_llm_engine import AsyncLLMEngine
-        from vllm.transformers_utils.tokenizer import get_tokenizer
-        from api.core.vllm_engine import VllmEngine
+        from api.core.vllm_engine import VllmEngine, LoRA
     except ImportError:
         return None

@@ -88,32 +87,36 @@ def create_vllm_engine():
         "dtype",
         "gpu_memory_utilization",
         "max_num_seqs",
+        "enforce_eager",
+        "max_context_len_to_capture",
+        "max_loras",
+        "max_lora_rank",
+        "lora_extra_vocab_size",
     }
     kwargs = model_dump(SETTINGS, include=include)
     engine_args = AsyncEngineArgs(
         model=SETTINGS.model_path,
         max_num_batched_tokens=SETTINGS.max_num_batched_tokens if SETTINGS.max_num_batched_tokens > 0 else None,
         max_model_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
         quantization=SETTINGS.quantization_method,
+        max_cpu_loras=SETTINGS.max_cpu_loras if SETTINGS.max_cpu_loras > 0 else None,
         **kwargs,
     )
     engine = AsyncLLMEngine.from_engine_args(engine_args)

-    # A separate tokenizer to map token IDs to strings.
-    tokenizer = get_tokenizer(
-        engine_args.tokenizer,
-        tokenizer_mode=engine_args.tokenizer_mode,
-        trust_remote_code=True,
-    )
-
     logger.info("Using vllm engine")

+    lora_modules = []
+    for item in SETTINGS.lora_modules.strip().split("+"):
+        if "=" in item:
+            name, path = item.split("=")
+            lora_modules.append(LoRA(name, path))
+
     return VllmEngine(
         engine,
-        tokenizer,
         SETTINGS.model_name,
         SETTINGS.chat_template,
-        SETTINGS.context_length,
+        lora_modules=lora_modules,
     )


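The LORA_MODULES setting consumed here is a plain string of "+"-separated name=path pairs. A standalone sketch of the same parsing logic, with hypothetical adapter names and paths:

    raw = "sql-lora=/data/loras/sql+chat-lora=/data/loras/chat"  # example LORA_MODULES value

    lora_modules = []
    for item in raw.strip().split("+"):
        if "=" in item:
            name, path = item.split("=")
            lora_modules.append((name, path))

    # -> [("sql-lora", "/data/loras/sql"), ("chat-lora", "/data/loras/chat")]
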
api/routes/model.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 from pydantic import BaseModel

 from api.config import SETTINGS
+from api.models import GENERATE_ENGINE
 from api.utils.request import check_api_key

 model_router = APIRouter()
@@ -30,7 +31,7 @@ class ModelList(BaseModel):

 @model_router.get("/models", dependencies=[Depends(check_api_key)])
 async def show_available_models():
-    return available_models
+    return await GENERATE_ENGINE.show_available_models() if SETTINGS.engine == "vllm" else available_models


 @model_router.get("/models/{model}", dependencies=[Depends(check_api_key)])

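With the vllm engine selected, the models endpoint now delegates to VllmEngine.show_available_models(), so the response lists the base model plus one card per LoRA adapter. A client-side sketch, assuming the router is mounted under /v1 on localhost and a placeholder API key is accepted:

    import httpx

    resp = httpx.get(
        "http://localhost:8000/v1/models",                    # assumed mount point
        headers={"Authorization": "Bearer sk-placeholder"},   # placeholder key
    )
    for card in resp.json()["data"]:
        print(card["id"], card["owned_by"])  # e.g. "qwen-14b-chat vllm", "sql-lora vllm"
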
api/utils/protocol.py

Lines changed: 20 additions & 0 deletions
@@ -216,6 +216,16 @@ class ChatCompletionCreateParams(BaseModel):

     min_p: Optional[float] = 0.0

+    include_stop_str_in_output: Optional[bool] = False
+
+    length_penalty: Optional[float] = 1.0
+
+    guided_json: Optional[Union[str, dict, BaseModel]] = None
+
+    guided_regex: Optional[str] = None
+
+    guided_choice: Optional[List[str]] = None
+

 class CompletionCreateParams(BaseModel):
     model: str
@@ -396,6 +406,16 @@ class CompletionCreateParams(BaseModel):

     min_p: Optional[float] = 0.0

+    include_stop_str_in_output: Optional[bool] = False
+
+    length_penalty: Optional[float] = 1.0
+
+    guided_json: Optional[Union[str, dict, BaseModel]] = None
+
+    guided_regex: Optional[str] = None
+
+    guided_choice: Optional[List[str]] = None
+

 class EmbeddingCreateParams(BaseModel):
     input: Union[str, List[str], List[int], List[List[int]]]

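The new request fields (include_stop_str_in_output, length_penalty, and the guided_* options) travel as extra JSON fields on an otherwise standard chat or text completion request. A sketch using the OpenAI Python client's extra_body pass-through; the base URL, API key, and model name are placeholders:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-placeholder")

    completion = client.chat.completions.create(
        model="qwen-14b-chat",
        messages=[{"role": "user", "content": "Is this review positive or negative?"}],
        extra_body={
            "guided_choice": ["positive", "negative"],  # constrain the answer
            "include_stop_str_in_output": False,
            "length_penalty": 1.0,
        },
    )
    print(completion.choices[0].message.content)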