@@ -380,6 +380,7 @@ def generate(
380
380
lora_request : Optional [LoRARequest ] = None ,
381
381
trace_headers : Optional [Mapping [str , str ]] = None ,
382
382
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
383
+ priority : int = 0 ,
383
384
) -> AsyncGenerator [RequestOutput , None ]:
384
385
...
385
386
@@ -392,6 +393,7 @@ def generate(
392
393
lora_request : Optional [LoRARequest ] = None ,
393
394
trace_headers : Optional [Mapping [str , str ]] = None ,
394
395
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
396
+ priority : int = 0 ,
395
397
) -> AsyncGenerator [RequestOutput , None ]:
396
398
...
397
399
@@ -407,6 +409,7 @@ def generate(
407
409
lora_request : Optional [LoRARequest ] = None ,
408
410
trace_headers : Optional [Mapping [str , str ]] = None ,
409
411
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
412
+ priority : int = 0 ,
410
413
* ,
411
414
inputs : Optional [PromptType ] = None # DEPRECATED
412
415
) -> AsyncGenerator [RequestOutput , None ]:
@@ -425,6 +428,9 @@ def generate(
425
428
trace_headers: OpenTelemetry trace headers.
426
429
prompt_adapter_request: Prompt Adapter request to use
427
430
for generation, if any.
431
+ priority: Priority of the request (lower means earlier handling).
432
+ Any priority other than 0 will lead to an error if the
433
+ scheduling policy is not "priority".
428
434
"""
429
435
if inputs is not None :
430
436
prompt = inputs
@@ -433,7 +439,7 @@ def generate(
433
439
434
440
return self ._process_request (prompt , sampling_params , request_id ,
435
441
lora_request , trace_headers ,
436
- prompt_adapter_request )
442
+ prompt_adapter_request , priority )
437
443
438
444
@overload # DEPRECATED
439
445
def encode (
@@ -444,6 +450,7 @@ def encode(
444
450
request_id : str ,
445
451
lora_request : Optional [LoRARequest ] = None ,
446
452
trace_headers : Optional [Mapping [str , str ]] = None ,
453
+ priority : int = 0 ,
447
454
) -> AsyncGenerator [EmbeddingRequestOutput , None ]:
448
455
...
449
456
@@ -455,6 +462,7 @@ def encode(
455
462
request_id : str ,
456
463
lora_request : Optional [LoRARequest ] = None ,
457
464
trace_headers : Optional [Mapping [str , str ]] = None ,
465
+ priority : int = 0 ,
458
466
) -> AsyncGenerator [EmbeddingRequestOutput , None ]:
459
467
...
460
468
@@ -469,6 +477,7 @@ def encode(
469
477
request_id : Optional [str ] = None ,
470
478
lora_request : Optional [LoRARequest ] = None ,
471
479
trace_headers : Optional [Mapping [str , str ]] = None ,
480
+ priority : int = 0 ,
472
481
* ,
473
482
inputs : Optional [PromptType ] = None # DEPRECATED
474
483
) -> AsyncGenerator [EmbeddingRequestOutput , None ]:
@@ -496,7 +505,7 @@ def encode(
496
505
and request_id is not None )
497
506
498
507
return self ._process_request (prompt , pooling_params , request_id ,
499
- lora_request , trace_headers )
508
+ lora_request , trace_headers , priority = priority )
500
509
501
510
async def _process_request (
502
511
self ,
@@ -505,7 +514,8 @@ async def _process_request(
505
514
request_id : str ,
506
515
lora_request : Optional [LoRARequest ] = None ,
507
516
trace_headers : Optional [Mapping [str , str ]] = None ,
508
- prompt_adapter_request : Optional [PromptAdapterRequest ] = None
517
+ prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
518
+ priority : int = 0 ,
509
519
) -> Union [AsyncGenerator [RequestOutput , None ], AsyncGenerator [
510
520
EmbeddingRequestOutput , None ]]:
511
521
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -550,7 +560,9 @@ async def _process_request(
550
560
request_id = request_id ,
551
561
lora_request = lora_request ,
552
562
trace_headers = trace_headers ,
553
- prompt_adapter_request = prompt_adapter_request ))
563
+ prompt_adapter_request = prompt_adapter_request ,
564
+ priority = priority ,
565
+ ))
554
566
555
567
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
556
568
parts = (request_bytes ,
0 commit comments