Refactor vllm and rubikon engine rolling batch (#1623)
rohithkrn authored Mar 12, 2024
1 parent 2e29ab7 commit f7aefee
Showing 6 changed files with 348 additions and 292 deletions.
@@ -0,0 +1,53 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from enum import Enum
from typing import Optional

from pydantic.v1.class_validators import validator, root_validator

from djl_python.properties_manager.properties import Properties


class LmiDistV2QuantizeMethods(str, Enum):
awq = 'awq'
gptq = 'gptq'
squeezellm = 'squeezellm'


class LmiDistV2RbProperties(Properties):
engine: Optional[str] = None
dtype: Optional[str] = "auto"
load_format: Optional[str] = "auto"
quantize: Optional[LmiDistV2QuantizeMethods] = None
tensor_parallel_degree: Optional[int] = None
max_rolling_batch_prefill_tokens: Optional[int] = None
    # Adjustable maximum model (context) length, e.g. for 32k or longer context models
max_model_len: Optional[int] = None
# TODO: change Enforce eager to False once SageMaker driver issue resolved
enforce_eager: Optional[bool] = False
# TODO: this default may change with different vLLM versions
# TODO: try to get good default from vLLM to prevent revisiting
# TODO: last time check: vllm 0.3.1
gpu_memory_utilization: Optional[float] = 0.9
# TODO: speculative decoding changes
speculative_draft_model: Optional[str] = None
speculative_length: int = 5
draft_model_tp_size: int = 1
record_acceptance_rate: Optional[bool] = False

@validator('engine')
def validate_engine(cls, engine):
if engine != "MPI":
raise AssertionError(
f"Need MPI engine to start lmidist_v2 RollingBatcher")
return engine
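
For orientation, a minimal usage sketch of the new properties class. The property keys below are assumptions drawn from fields referenced elsewhere in this diff, and the base Properties class may require or validate additional keys.

from djl_python.properties_manager.lmi_dist_v2_rb_properties import LmiDistV2RbProperties

# Hypothetical properties dict; only keys referenced in this diff are used here.
props = {
    "engine": "MPI",                                  # anything else fails the validator above
    "model_id_or_path": "TheBloke/Llama-2-7B-AWQ",    # placeholder model id
    "tensor_parallel_degree": 2,
    "dtype": "fp16",
    "quantize": "awq",
}
config = LmiDistV2RbProperties(**props)
assert config.gpu_memory_utilization == 0.9           # class default shown above

# engine != "MPI" surfaces as a pydantic ValidationError wrapping the AssertionError:
# LmiDistV2RbProperties(**{**props, "engine": "Python"})
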
@@ -15,7 +15,7 @@

from pydantic.v1.class_validators import validator, root_validator

from djl_python.properties_manager.properties import Properties, RollingBatchEnum
from djl_python.properties_manager.properties import Properties


class VllmQuantizeMethods(str, Enum):
@@ -44,16 +44,9 @@ class VllmRbProperties(Properties):
draft_model_tp_size: int = 1
record_acceptance_rate: Optional[bool] = False

@root_validator(skip_on_failure=True)
def validate_engine(cls, properties):
engine = properties["engine"]
rolling_batch = properties["rolling_batch"]
if rolling_batch == RollingBatchEnum.vllm and engine != "Python":
@validator('engine')
def validate_engine(cls, engine):
if engine != "Python":
raise AssertionError(
f"Need python engine to start vLLM RollingBatcher")

if rolling_batch == RollingBatchEnum.lmidist_v2 and engine != "MPI":
raise AssertionError(
f"Need MPI engine to start lmidist_v2 RollingBatcher")

return properties
return engine
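
The change above swaps a cross-field root_validator for a per-field validator: now that lmi-dist and vLLM each have their own properties class, the engine check no longer needs to look at the rolling_batch value. A self-contained sketch of the two pydantic v1 patterns, using illustrative models rather than the real classes:

# pydantic 2.x exposes the v1 API under pydantic.v1, matching the imports above
from pydantic.v1 import BaseModel, root_validator, validator


class CrossFieldCheck(BaseModel):
    # old pattern: the engine constraint depends on another field
    engine: str
    rolling_batch: str

    @root_validator(skip_on_failure=True)
    def validate_engine(cls, values):
        if values["rolling_batch"] == "vllm" and values["engine"] != "Python":
            raise AssertionError("Need python engine to start vLLM RollingBatcher")
        return values


class PerFieldCheck(BaseModel):
    # new pattern: the constraint is local to the engine field itself
    engine: str

    @validator('engine')
    def validate_engine(cls, engine):
        if engine != "Python":
            raise AssertionError("Need python engine to start vLLM RollingBatcher")
        return engine
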
@@ -10,18 +10,21 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import logging
from collections import OrderedDict

from lmi_dist.api import Request
from lmi_dist.init_engine import engine_from_args
from vllm import EngineArgs, SamplingParams

from djl_python.rolling_batch.vllm_rolling_batch_base import VllmRollingBatchBase, DTYPE_MAPPER
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
get_speculative_decoding_metrics_record, update_request_cache_with_output,
supports_speculative_decoding, DTYPE_MAPPER, FINISH_REASON_MAPPER)
from djl_python.properties_manager.lmi_dist_v2_rb_properties import LmiDistV2RbProperties


class LmiDistRollingBatch(VllmRollingBatchBase):
class LmiDistRollingBatch(RollingBatch):
"""
LmiDistRollingBatch connects handler to LmiDist backend engine. It receives new
requests from the handler and sends them to the backend when space is available in the batch.
@@ -35,29 +38,38 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
:param model_id_or_path (str): Currently unused since there is a copy inside properties
:param properties (dict): other properties of the model, such as decoder strategy
"""
engine_config = VllmRbProperties(**properties)
super().__init__(engine_config, kwargs.get("model_config", None))
self.init_engine()

def init_engine(self):
"""
Initializes vllm engine
"""
self.lmi_dist_config = LmiDistV2RbProperties(**properties)
super().__init__(
waiting_steps=self.lmi_dist_config.waiting_steps,
output_formatter=self.lmi_dist_config.output_formatter)
self.supports_speculative_decoding = supports_speculative_decoding()
engine_kwargs = {}
if self.supports_speculative_decoding:
            engine_kwargs[
                "draft_model"] = self.lmi_dist_config.speculative_draft_model
            engine_kwargs[
                "speculate_length"] = self.lmi_dist_config.speculative_length
            engine_kwargs[
                "draft_model_tp_size"] = self.lmi_dist_config.draft_model_tp_size
args = EngineArgs(
model=self.engine_config.model_id_or_path,
tensor_parallel_size=self.engine_config.tensor_parallel_degree,
dtype=DTYPE_MAPPER[self.engine_config.dtype],
model=self.lmi_dist_config.model_id_or_path,
tensor_parallel_size=self.lmi_dist_config.tensor_parallel_degree,
dtype=DTYPE_MAPPER[self.lmi_dist_config.dtype],
seed=0,
max_model_len=self.engine_config.max_model_len,
enforce_eager=self.engine_config.enforce_eager,
gpu_memory_utilization=self.engine_config.gpu_memory_utilization,
max_num_batched_tokens=self.engine_config.
max_model_len=self.lmi_dist_config.max_model_len,
enforce_eager=self.lmi_dist_config.enforce_eager,
gpu_memory_utilization=self.lmi_dist_config.gpu_memory_utilization,
max_num_batched_tokens=self.lmi_dist_config.
max_rolling_batch_prefill_tokens,
trust_remote_code=self.engine_config.trust_remote_code,
load_format=self.engine_config.load_format,
quantization=self.engine_config.quantize,
revision=self.engine_config.revision)
trust_remote_code=self.lmi_dist_config.trust_remote_code,
load_format=self.lmi_dist_config.load_format,
quantization=self.lmi_dist_config.quantize,
revision=self.lmi_dist_config.revision,
**engine_kwargs)
self.engine = engine_from_args(args)
self.request_cache = OrderedDict()
self.model_type = getattr(kwargs.get("model_config", None),
"model_type", None)

def reset(self) -> None:
"""
@@ -67,27 +79,22 @@ def reset(self) -> None:
self.request_cache = OrderedDict()
super().reset()

def add_request(self, request_id: str, prompt: str,
sampling_params: SamplingParams):
"""
Adds request to the engine
"""
lmi_dist_request = Request(id=request_id,
prompt=prompt,
sampling_params=sampling_params)
self.engine.add_request(lmi_dist_request)

def translate_to_engine_params(self, parameters: dict):
def translate_lmi_dist_params(self, parameters: dict):
"""
Helper function to convert DJL Serving parameter names to parameter names
that lmidist_v2 recognizes.
:param parameters (dict): Parameters pertaining to a specific request
:return: The same parameters dict, but with VLLM style parameter names.
:return: The same parameters dict, but with lmi-dist style parameter names.
"""
parameters.pop('seed', None)
parameters.pop('do_sample', None)
do_sample = parameters.pop('do_sample', False)
if do_sample and "temperature" not in parameters.keys():
parameters["temperature"] = 1.0
else:
parameters["temperature"] = 0.0
if "seed" in parameters.keys():
parameters["seed"] = int(parameters["seed"])
if "max_new_tokens" in parameters.keys():
parameters["max_tokens"] = parameters.pop("max_new_tokens")
if "stop_sequences" in parameters.keys():
@@ -96,15 +103,83 @@ def translate_to_engine_params(self, parameters: dict):
parameters["ignore_eos"] = parameters.pop("ignore_eos")
return parameters
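
A rough worked example of the visible mappings above (the stop_sequences and ignore_eos handling is collapsed in this view, so it is not traced here):

# Hypothetical incoming request parameters from the handler
params = {"do_sample": True, "seed": "42", "max_new_tokens": 128}

# After translate_lmi_dist_params:
#   do_sample is consumed; with no explicit temperature, temperature -> 1.0
#   seed is cast to int                                              -> 42
#   max_new_tokens is renamed to the lmi-dist name                   -> max_tokens
expected = {"temperature": 1.0, "seed": 42, "max_tokens": 128}
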

def get_request_id(self, request):
@stop_on_any_exception
def inference(self, input_data: list[str], parameters: list[dict]) -> list:
"""
Get request id that will be set to backend engine request
Adds new requests and gets output tokens from the backend.
:param input_data: List of input prompts.
:param parameters: List of settings pertaining to each request.
:return results: List of dictionaries, one for each request, that contain output tokens and other data.
"""
return str(request.id)
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
batch_size)
# step 0: register new requests to engine
for request in new_requests:
request_id = str(request.id)
params = self.translate_lmi_dist_params(request.parameters)
sampling_params = SamplingParams(**params)
lmi_dist_request = Request(id=request_id,
prompt=request.input_text,
sampling_params=sampling_params)
self.engine.add_request(lmi_dist_request)
self.request_cache[request_id] = {
"curr_length": 0,
"text": "",
"cumulative_logprob": 0.0,
"log_prob": 0.0,
"finished": False,
"finish_reason": None
}
request_outputs = self.engine.step()

# step 1: put result to cache
for request_output in request_outputs:
self.request_cache = update_request_cache_with_output(
self.request_cache, request_output)
# Record SD metrics
completion_output = request_output.outputs[0]
if self.lmi_dist_config.record_acceptance_rate and request_output.finished:
if self.supports_speculative_decoding and completion_output.acceptance_history:
record = get_speculative_decoding_metrics_record(
completion_output, request_output)
logging.info(f"Speculative Decoding {record}")
else:
logging.warning(
f"Ignoring logging speculative decoding metrics")

# step 2: send result back
finished_id = []
for (key, cache), request in zip(self.request_cache.items(),
self.active_requests):
finish_reason = None
if cache["finished"]:
finished_id.append(key)
finish_reason = FINISH_REASON_MAPPER.get(
cache["finish_reason"], None)
text = cache["text"][cache["curr_length"]:]
if len(text) > 0:
                # the token id is not fully determined, since multiple tokens
                # may arrive in the same engine step; only the last one is returned
token = Token(cache['id'], text, cache["log_prob"])
request.set_next_token(token, self.output_formatter,
cache["finished"], finish_reason)
else:
request.set_next_token("", self.output_formatter,
cache["finished"], finish_reason)
cache["curr_length"] = len(cache["text"])

# step 3: clean finished requests
for key in finished_id:
self.request_cache.pop(key)

return self.postprocess_results()

def preprocess_requests(self, requests):
"""
Currently not applicable for VLLM.
Currently not applicable for lmi-dist-v2.
"""
raise NotImplementedError(
"Not implemented for lmidist_v2 rolling batcher")
@@ -0,0 +1,71 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict
from vllm import EngineArgs
from vllm.outputs import CompletionOutput, RequestOutput

DTYPE_MAPPER = {
"fp32": "float32",
"fp16": "float16",
"bf16": "bfloat16",
"auto": "auto"
}

FINISH_REASON_MAPPER = {
"length": "length",
"stop": "eos_token",
"abort": "abort"
}


def update_request_cache_with_output(
request_cache: OrderedDict,
request_output: RequestOutput) -> OrderedDict:
request_id = request_output.request_id
request_cache[request_id]["id"] = request_output.outputs[0].token_ids[-1]
request_cache[request_id]["text"] = request_output.outputs[0].text
# calculate log_prob of the token based on the diff between two cumulative log probs
request_cache[request_id]["log_prob"] = request_output.outputs[
0].cumulative_logprob - request_cache[request_id]["cumulative_logprob"]
request_cache[request_id]["cumulative_logprob"] = request_output.outputs[
0].cumulative_logprob
request_cache[request_id]["finish_reason"] = request_output.outputs[
0].finish_reason
if len(request_output.outputs) > 1:
logging.warning(
f"Finding more than 1 output for single request {len(request_output.outputs)}"
f"Beam search is not supported yet, use first output by default")
request_cache[request_id]["finished"] = request_output.finished
return request_cache
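
A small worked example of the cumulative-logprob bookkeeping above (the numbers are made up):

prev_cumulative = -1.25   # stored in request_cache after the previous step
new_cumulative = -1.95    # outputs[0].cumulative_logprob after one more token
log_prob = new_cumulative - prev_cumulative   # -0.70, logprob of the newest token
# the cache then advances cumulative_logprob to -1.95 for the next step
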


def get_speculative_decoding_metrics_record(
completion_output: CompletionOutput,
request_output: RequestOutput) -> dict:
request_id = request_output.request_id
record = {}
record["id"] = request_id
if len(completion_output.acceptance_history) > 0:
record["mean_acceptance"] = 1.0 * sum(
completion_output.acceptance_history) / len(
completion_output.acceptance_history)
else:
record["mean_acceptance"] = 0
record["prompt_size"] = len(request_output.prompt_token_ids)
record["output_size"] = len(completion_output.token_ids)
return record


def supports_speculative_decoding() -> bool:
return "draft_model" in EngineArgs.__annotations__