
Commit 7ae2997

schoennenbeck authored and garg-amit committed
[Core] [Frontend] Priority scheduling for embeddings and in the OpenAI-API (vllm-project#8965)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 90d4836 commit 7ae2997

8 files changed: +53 -5 lines changed

vllm/engine/async_llm_engine.py (+4)

@@ -1043,6 +1043,7 @@ async def encode(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

@@ -1057,6 +1058,8 @@ async def encode(
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
+           priority: The priority of the request.
+               Only applicable with priority scheduling.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
@@ -1109,6 +1112,7 @@ async def encode(
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
+           priority=priority,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

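The new argument is threaded through AsyncLLMEngine.encode, so callers that drive the engine directly can mark an embedding request for earlier handling. A minimal sketch of such a call, assuming an already-constructed AsyncLLMEngine; the helper and variable names here are illustrative, not part of the commit:

from vllm.pooling_params import PoolingParams

async def embed_with_priority(engine, prompt, request_id: str):
    # priority is only honored when the engine runs with priority
    # scheduling; lower values are meant to be handled earlier.
    final_output = None
    async for output in engine.encode(
            prompt,
            PoolingParams(),
            request_id,
            priority=0,
    ):
        final_output = output
    return final_output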
vllm/engine/multiprocessing/__init__.py (+5)

@@ -30,6 +30,7 @@ class RPCProcessRequest:
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
+   priority: int = 0

    @overload  # DEPRECATED
    def __init__(
@@ -41,6 +42,7 @@ def __init__(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
    ) -> None:
        ...

@@ -53,6 +55,7 @@ def __init__(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
    ) -> None:
        ...

@@ -68,6 +71,7 @@ def __init__(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
@@ -84,6 +88,7 @@ def __init__(
        self.lora_request = lora_request
        self.trace_headers = trace_headers
        self.prompt_adapter_request = prompt_adapter_request
+       self.priority = priority


@dataclass

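With the extra dataclass field, the priority value rides along in the pickled RPC payload from the API-server process to the engine process. A rough sketch of what now gets serialized; the prompt/params/request_id fields are assumed from the surrounding class definition and are not shown in the hunks above:

import pickle

from vllm import SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

request = RPCProcessRequest(
    prompt="Hello, world",
    params=SamplingParams(max_tokens=8),
    request_id="req-0",
    priority=0,  # defaults to 0; non-zero requires the "priority" policy
)
request_bytes = pickle.dumps(request)  # sent to the MQLLMEngine by the client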
vllm/engine/multiprocessing/client.py (+16 -4)

@@ -380,6 +380,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        ...

@@ -392,6 +393,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        ...

@@ -407,6 +409,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
        *,
        inputs: Optional[PromptType] = None  # DEPRECATED
    ) -> AsyncGenerator[RequestOutput, None]:
@@ -425,6 +428,9 @@ def generate(
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
+           priority: Priority of the request (lower means earlier handling).
+               Any priority other than 0 will lead to an error if the
+               scheduling policy is not "priority".
        """
        if inputs is not None:
            prompt = inputs
@@ -433,7 +439,7 @@ def generate(

        return self._process_request(prompt, sampling_params, request_id,
                                     lora_request, trace_headers,
-                                    prompt_adapter_request)
+                                    prompt_adapter_request, priority)

    @overload  # DEPRECATED
    def encode(
@@ -444,6 +450,7 @@ def encode(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        ...

@@ -455,6 +462,7 @@ def encode(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        ...

@@ -469,6 +477,7 @@ def encode(
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
+       priority: int = 0,
        *,
        inputs: Optional[PromptType] = None  # DEPRECATED
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
@@ -496,7 +505,7 @@ def encode(
                    and request_id is not None)

        return self._process_request(prompt, pooling_params, request_id,
-                                    lora_request, trace_headers)
+                                    lora_request, trace_headers, priority)

    async def _process_request(
        self,
@@ -505,7 +514,8 @@ async def _process_request(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
-       prompt_adapter_request: Optional[PromptAdapterRequest] = None
+       prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
            EmbeddingRequestOutput, None]]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -550,7 +560,9 @@ async def _process_request(
                    request_id=request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
-                   prompt_adapter_request=prompt_adapter_request))
+                   prompt_adapter_request=prompt_adapter_request,
+                   priority=priority,
+               ))

            # 3) Send the RPCGenerateRequest to the MQLLMEngine.
            parts = (request_bytes,

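On the client side of the multiprocessing frontend, generate and encode now simply forward priority to _process_request. A hedged sketch of the calling convention, assuming client is an already-started MQLLMEngineClient; the helper name and prompt are illustrative:

from vllm import SamplingParams

async def generate_with_priority(client, prompt: str, request_id: str):
    # Non-zero priorities raise an error unless the engine was started
    # with the "priority" scheduling policy.
    last = None
    async for output in client.generate(
            prompt,
            SamplingParams(max_tokens=32),
            request_id,
            priority=0,
    ):
        last = output
    return last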
vllm/engine/protocol.py (+3 -1)

@@ -40,7 +40,8 @@ def generate(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
-       prompt_adapter_request: Optional[PromptAdapterRequest] = None
+       prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...
@@ -52,6 +53,7 @@ def encode(
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
+       priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
        ...

vllm/entrypoints/openai/protocol.py (+22)

@@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
+   priority: int = Field(
+       default=0,
+       description=(
+           "The priority of the request (lower means earlier handling; "
+           "default: 0). Any priority other than 0 will raise an error "
+           "if the served model does not use priority scheduling."))

    # doc: end-chat-completion-extra-params

@@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
+   priority: int = Field(
+       default=0,
+       description=(
+           "The priority of the request (lower means earlier handling; "
+           "default: 0). Any priority other than 0 will raise an error "
+           "if the served model does not use priority scheduling."))

    # doc: end-completion-extra-params

@@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):

    # doc: end-embedding-pooling-params

+   # doc: begin-embedding-extra-params
+   priority: int = Field(
+       default=0,
+       description=(
+           "The priority of the request (lower means earlier handling; "
+           "default: 0). Any priority other than 0 will raise an error "
+           "if the served model does not use priority scheduling."))
+
+   # doc: end-embedding-extra-params
+
    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)

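Because priority is declared as an extra field on the request models, OpenAI-compatible clients can pass it through without any schema change on their side. A sketch using the openai-python client's extra_body escape hatch; the base URL and model name are placeholders, and the server is assumed to be configured for priority scheduling:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="my-served-model",  # placeholder: whatever model the vLLM server serves
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"priority": 0},  # lower value = earlier handling
)
print(response.choices[0].message.content)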
vllm/entrypoints/openai/serving_chat.py (+1)

@@ -235,6 +235,7 @@ async def create_chat_completion(
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
+               priority=request.priority,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error

vllm/entrypoints/openai/serving_completion.py (+1)

@@ -148,6 +148,7 @@ async def create_completion(
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request,
                    trace_headers=trace_headers,
+                   priority=request.priority,
                )

                generators.append(generator)

vllm/entrypoints/openai/serving_embedding.py (+1)

@@ -148,6 +148,7 @@ async def create_embedding(
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
+                   priority=request.priority,
                )

                generators.append(generator)

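The embeddings route gets the same treatment, so the priority field applies to /v1/embeddings as well. A sketch of an end-to-end call, assuming an embedding model is being served and the server was launched with a priority-capable scheduler (for example via vLLM's --scheduling-policy priority option, if available in the version you run); the model name is a placeholder:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

embedding = client.embeddings.create(
    model="my-embedding-model",  # placeholder for the served embedding model
    input=["priority scheduling for embeddings"],
    extra_body={"priority": 0},  # forwarded through to AsyncLLMEngine.encode
)
print(len(embedding.data[0].embedding))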