-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor vllm and rubikon engine rolling batch (#1623)
- Loading branch information
Showing
6 changed files
with
348 additions
and
292 deletions.
There are no files selected for viewing
53 changes: 53 additions & 0 deletions
53
engines/python/setup/djl_python/properties_manager/lmi_dist_v2_rb_properties.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/usr/bin/env python | ||
# | ||
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file | ||
# except in compliance with the License. A copy of the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" | ||
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for | ||
# the specific language governing permissions and limitations under the License. | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from pydantic.v1.class_validators import validator, root_validator | ||
|
||
from djl_python.properties_manager.properties import Properties | ||
|
||
|
||
class LmiDistV2QuantizeMethods(str, Enum):
    """Quantization schemes accepted by the lmidist_v2 rolling batcher.

    Inherits from ``str`` so values compare equal to their plain string
    form and serialize naturally in pydantic models.
    """
    awq = 'awq'
    gptq = 'gptq'
    squeezellm = 'squeezellm'
|
||
|
||
class LmiDistV2RbProperties(Properties):
    """Configuration properties for the lmidist_v2 rolling batcher.

    Field defaults mirror the underlying vLLM engine arguments; fields left
    as ``None`` defer to the engine's own defaults.
    """
    engine: Optional[str] = None
    dtype: Optional[str] = "auto"
    load_format: Optional[str] = "auto"
    quantize: Optional[LmiDistV2QuantizeMethods] = None
    tensor_parallel_degree: Optional[int] = None
    max_rolling_batch_prefill_tokens: Optional[int] = None
    # Adjustable prefix model length for certain 32k or longer model
    max_model_len: Optional[int] = None
    # TODO: change Enforce eager to False once SageMaker driver issue resolved
    enforce_eager: Optional[bool] = False
    # TODO: this default may change with different vLLM versions
    # TODO: try to get good default from vLLM to prevent revisiting
    # TODO: last time check: vllm 0.3.1
    gpu_memory_utilization: Optional[float] = 0.9
    # TODO: speculative decoding changes
    speculative_draft_model: Optional[str] = None
    speculative_length: int = 5
    draft_model_tp_size: int = 1
    record_acceptance_rate: Optional[bool] = False

    @validator('engine')
    def validate_engine(cls, engine):
        # lmidist_v2 is only supported on the MPI engine; reject anything else.
        if engine == "MPI":
            return engine
        raise AssertionError(
            f"Need MPI engine to start lmidist_v2 RollingBatcher")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python | ||
# | ||
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file | ||
# except in compliance with the License. A copy of the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" | ||
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for | ||
# the specific language governing permissions and limitations under the License. | ||
import logging | ||
from collections import OrderedDict | ||
from vllm import EngineArgs | ||
from vllm.outputs import CompletionOutput, RequestOutput | ||
|
||
# Map DJL-style short dtype names to the spellings vLLM expects.
DTYPE_MAPPER = {
    "fp32": "float32",
    "fp16": "float16",
    "bf16": "bfloat16",
    "auto": "auto",
}
|
||
# Translate vLLM finish reasons to the names used in DJL responses.
FINISH_REASON_MAPPER = {
    "length": "length",
    "stop": "eos_token",
    "abort": "abort",
}
|
||
|
||
def update_request_cache_with_output(
        request_cache: OrderedDict,
        request_output: "RequestOutput") -> OrderedDict:
    """Fold one vLLM ``RequestOutput`` into the per-request cache entry.

    Updates the cache entry keyed by the request id with the newest token id,
    generated text, per-step log prob, cumulative log prob, finish reason and
    finished flag. Only the first output sequence is used; beam search is not
    supported yet.

    :param request_cache: cache mapping request_id -> state dict; the entry
        must already contain a "cumulative_logprob" value from the prior step.
    :param request_output: the latest engine output for one request.
    :return: the same ``request_cache``, updated in place.
    """
    request_id = request_output.request_id
    if len(request_output.outputs) > 1:
        # Warn before touching the cache so the message appears even if a
        # later field access fails. Original message concatenated two
        # f-string fragments with no separator; fixed here.
        logging.warning(
            f"Found more than 1 output for single request {len(request_output.outputs)}. "
            "Beam search is not supported yet; using first output by default.")
    # Beam search unsupported: always read the first candidate sequence.
    completion = request_output.outputs[0]
    entry = request_cache[request_id]
    entry["id"] = completion.token_ids[-1]
    entry["text"] = completion.text
    # Per-token log prob is the delta between successive cumulative log probs.
    entry["log_prob"] = completion.cumulative_logprob - entry[
        "cumulative_logprob"]
    entry["cumulative_logprob"] = completion.cumulative_logprob
    entry["finish_reason"] = completion.finish_reason
    entry["finished"] = request_output.finished
    return request_cache
|
||
|
||
def get_speculative_decoding_metrics_record(
        completion_output: "CompletionOutput",
        request_output: "RequestOutput") -> dict:
    """Build a metrics record for one speculative-decoding request.

    :param completion_output: the chosen output sequence; its
        ``acceptance_history`` holds per-step draft-token acceptance counts.
    :param request_output: the engine output carrying the request id and
        prompt token ids.
    :return: dict with "id", "mean_acceptance" (0 when no history),
        "prompt_size" and "output_size".
    """
    history = completion_output.acceptance_history
    return {
        "id": request_output.request_id,
        # Mean accepted draft tokens per step; 0 when nothing was recorded.
        "mean_acceptance": 1.0 * sum(history) / len(history) if history else 0,
        "prompt_size": len(request_output.prompt_token_ids),
        "output_size": len(completion_output.token_ids),
    }
|
||
|
||
def supports_speculative_decoding() -> bool:
    """Return True when the installed vLLM build accepts a draft model.

    Detected by checking whether ``EngineArgs`` declares a ``draft_model``
    field in its class annotations.
    """
    declared_fields = EngineArgs.__annotations__
    return "draft_model" in declared_fields
Oops, something went wrong.