
Commit

add some more
ydm-amazon committed Feb 28, 2024
1 parent 1c3f2e2 commit f48b781
Showing 5 changed files with 79 additions and 33 deletions.
@@ -24,13 +24,14 @@
import torch

from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties
from typing import Optional

QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes8", "bitsandbytes", "gptq", "awq"]


class LmiDistRollingBatch(RollingBatch):

def __init__(self, model_id_or_path, properties, **kwargs):
def __init__(self, model_id_or_path: str, properties: dict, **kwargs) -> None:
"""
Initializes the LmiDistRollingBatch.
@@ -50,12 +50,21 @@ def __init__(self, model_id_or_path, properties, **kwargs):
self.batch_id_counter = 0
self.cache = {}

def reset(self):
def reset(self) -> None:
"""
Aborts all requests.
"""
self.cache.clear()
self.batch_id_counter = 0
super().reset()

def _init_model(self, model_id_or_path, draft_model_id=None):
def _init_model(self, model_id_or_path: str, draft_model_id: Optional[str] = None) -> None:
"""
Helper function for __init__ that creates a model in the lmi-dist backend.
:param model_id_or_path: model id or path
:param draft_model_id: model ID of draft model in speculative decoding, if applicable
"""
sharded = self.lmi_dist_configs.tensor_parallel_degree > 1
quantize = self.lmi_dist_configs.quantize
if quantize is not None:
@@ -78,7 +88,11 @@ def _init_model(self, model_id_or_path, draft_model_id=None):
self.batch_cls = self.model.batch_type
self._warmup()

def _warmup(self):
def _warmup(self) -> None:
"""
Sends synthetic requests through the model before any real requests arrive,
so the rolling batch starts from a warmed-up state.
"""
max_batch_prefill_tokens = self.lmi_dist_configs.max_rolling_batch_prefill_tokens

input_length = 512
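As a rough illustration of how these two values interact, here is a hedged, self-contained sketch (plan_warmup and its dummy-request dicts are hypothetical, not lmi-dist code) of sizing a warmup batch that fills the prefill token budget with 512-token inputs; the real _warmup may do this differently:

```python
# Hypothetical sketch: fill the prefill token budget with fixed-length dummy inputs.
# The real _warmup builds lmi-dist request objects; this only shows the sizing math.
def plan_warmup(max_batch_prefill_tokens: int, input_length: int = 512) -> list[dict]:
    n_requests = max(1, max_batch_prefill_tokens // input_length)
    return [{"input_length": input_length} for _ in range(n_requests)]

# e.g. a 4096-token prefill budget yields 8 dummy requests of 512 tokens each
print(len(plan_warmup(4096)))  # 8
```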
@@ -120,17 +134,19 @@ def _warmup(self):
logging.info(
f"The max total sequence length is {max_batch_total_tokens}")

def release_cache(self):
def release_cache(self) -> None:
self.model.release_cache()

@stop_on_any_exception
def inference(self, input_data, parameters):
def inference(self, input_data: list[str], parameters: list[dict]) -> list:
"""
Performs prefill and decode operations for the batch.
:param input_data: List of input texts for each request in a batch
:param parameters: List of kwargs for each request in a batch
:return: generated batch decoded tokens
:return: generated batch decoded tokens - a list of dictionaries, one per
request, each containing output tokens and other data.
"""
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
@@ -140,8 +156,13 @@ def inference(self, input_data, parameters):
self._prefill_and_decode(new_batch)
return self.postprocess_results()
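The inference() docstring above pins down the handler contract: parallel lists of prompts and per-request parameter dicts go in, and a list of per-request result dicts comes out. A hedged, self-contained sketch of that call shape (FakeRollingBatch and the result keys are illustrative stand-ins, not the real handler):

```python
class FakeRollingBatch:
    """Stand-in that mimics only the call shape of a real RollingBatch."""

    def inference(self, input_data: list[str], parameters: list[dict]) -> list[dict]:
        # One result dict per request; real handlers return generated tokens
        # plus bookkeeping such as whether the request has finished.
        return [{"data": f"<tokens for {prompt!r}>", "last": True}
                for prompt in input_data]

prompts = ["What is Deep Java Library?", "Write a haiku about GPUs."]
params = [{"max_new_tokens": 64}, {"max_new_tokens": 32, "temperature": 0.7}]
for result in FakeRollingBatch().inference(prompts, params):
    print(result)
```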

def _prefill_and_decode(self, new_batch):
def _prefill_and_decode(self, new_batch: Batch) -> None:
"""
Helper function for inference(): adds new requests to the batch
and gets output tokens from the model for each request.
:param new_batch: Contains all the new requests
The text quality issue from Nov. 2023 was temporarily worked around by [PR#1189: Fix lmi_dist garbage output
issue](https://github.com/deepjavalibrary/djl-serving/pull/1189). Its root cause is now believed to be
buggy memory management: the batch.release() was called inside
@@ -213,7 +234,15 @@ def _prefill_and_decode(self, new_batch):
self.cache[batch.batch_id] = self.cache[batch.batch_id].filter(
req_ids)

def preprocess_requests(self, requests, **kwargs):
def preprocess_requests(self, requests: list, **kwargs) -> Batch:
"""
Preprocesses requests by producing an aggregate batch object to send to
lmi-dist.
:param requests: list of Request objects
:return: An lmi-dist Batch (or Batch subclass) containing the preprocessed requests
"""
preprocessed_requests = []
for r in requests:
param = r.parameters
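preprocess_requests() above gathers each incoming Request's parameters before bundling everything into one lmi-dist Batch. A hedged, dictionary-only sketch of that aggregation step (field names and defaults are illustrative; the real handler builds lmi-dist request and sampling objects):

```python
def preprocess(requests: list[dict]) -> list[dict]:
    """Flatten per-request parameters into a list that is ready to be batched."""
    preprocessed = []
    for r in requests:
        param = r.get("parameters", {})
        preprocessed.append({
            "prompt": r["prompt"],
            # Illustrative settings only; the real handler maps many more.
            "max_new_tokens": param.get("max_new_tokens", 30),
            "temperature": param.get("temperature", 1.0),
        })
    return preprocessed

print(preprocess([{"prompt": "hello", "parameters": {"max_new_tokens": 8}}]))
```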
@@ -18,24 +18,38 @@

class NeuronRollingBatch(RollingBatch):

def __init__(self, model, tokenizer, batch_size, n_postions, **kwargs):
def __init__(self, model, tokenizer, batch_size: int, n_positions: int, **kwargs) -> None:
"""
Initializes the NeuronRollingBatch.
:param model: the Neuron HuggingFace model
:param tokenizer: the tokenizer used by the model
:param batch_size: the maximum batch size required by the model
:param n_positions: the maximum sequence length for the model
"""
super().__init__(**kwargs)
self.scheduler = NeuronGenerator(model, tokenizer, batch_size,
n_postions)
n_positions)

def reset(self):
def reset(self) -> None:
"""
Aborts all requests.
"""
self.scheduler.clear()
super().reset()
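With the corrected parameter documentation above, constructing the handler looks roughly like this hedged sketch (the import path, model/tokenizer loading, and the concrete numbers are assumptions, not part of this diff):

```python
# Assumed import path; adjust to wherever NeuronRollingBatch lives in djl_python.
from djl_python.rolling_batch.neuron_rolling_batch import NeuronRollingBatch

def build_rolling_batch(model, tokenizer, **kwargs) -> NeuronRollingBatch:
    """Wrap an already-compiled Neuron HuggingFace model in a rolling batch handler."""
    return NeuronRollingBatch(
        model,
        tokenizer,
        batch_size=4,      # maximum batch size the model was compiled for (assumed value)
        n_positions=2048,  # maximum sequence length per request (assumed value)
        **kwargs,
    )
```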

@stop_on_any_exception
def inference(self, input_data, parameters):
def inference(self, input_data: list[str], parameters: list[dict]) -> list:
"""
Loads new requests and gets output tokens for all currently active requests
from the Neuron backend.
:param input_data: List of input texts for each request in a batch
:param parameters: List of kwargs for each request in a batch
:return: generated batch decoded tokens - list of dictionaries, one for
each request, that contain output tokens and other data.
"""
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
batch_size)
@@ -85,5 +99,8 @@ def inference(self, input_data, parameters):
self.scheduler.filter(req_ids)
return self.postprocess_results()

def preprocess_requests(self, requests):
def preprocess_requests(self, requests: list):
"""
Currently not applicable for Neuron.
"""
raise NotImplementedError("Not implemented for Neuron rolling batcher")
@@ -50,7 +50,7 @@ class SchedulerRollingBatch(RollingBatch):
and other experimental features.
"""

def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
def __init__(self, model_id_or_path: str, properties: dict, **kwargs) -> None:
"""
Initializes the rolling batch scheduler.
@@ -120,7 +120,7 @@ def preprocess_requests(self, requests: list):

return new_requests

def _init_model_and_tokenizer(self):
def _init_model_and_tokenizer(self) -> None:
"""
Helper function for __init__ that creates a huggingface model and tokenizer.
"""
@@ -181,7 +181,7 @@ def _init_model_and_tokenizer(self):

self.tokenizer_streaming = TokenizerStreaming(self.tokenizer)

def _init_scheduler(self):
def _init_scheduler(self) -> None:
"""
Helper function for __init__ that creates a scheduler equipped with
the strategies defined in properties.
@@ -200,7 +200,7 @@ def _init_scheduler(self):
max_sparsity=self.scheduler_configs.max_sparsity,
max_splits=self.scheduler_configs.max_splits)

def _prefill_and_decode(self, new_requests):
def _prefill_and_decode(self, new_requests) -> None:
"""
Helper function for inference() that adds new requests to the batch.
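The SchedulerRollingBatch changes in this file are the same pattern applied throughout the commit: PEP 484 hints on every signature (including -> None for procedures) and Sphinx-style :param:/:return: docstrings with no type annotations repeated in the text. A minimal before/after illustration using a made-up helper, not code from this repository:

```python
from typing import Optional

# Before (old style): untyped signature, no docstring.
# def load_adapter(model_id_or_path, adapter_id=None):
#     ...

# After (the convention this commit applies):
def load_adapter(model_id_or_path: str, adapter_id: Optional[str] = None) -> None:
    """
    Loads an adapter for the given model.

    :param model_id_or_path: model id or path
    :param adapter_id: id of the adapter to load, if applicable
    """
```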
@@ -22,7 +22,7 @@ class TRTLLMRollingBatch(RollingBatch):
It also gets any new tokens from the backend and sends them back to the handler.
"""

def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
def __init__(self, model_id_or_path: str, properties: dict, **kwargs) -> None:
"""
Initializes the TRTLLMRollingBatch
@@ -37,7 +37,7 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
model_id_or_path, **kwargs)
self.request_cache = {}

def reset(self):
def reset(self) -> None:
"""
Stops all current requests and resets state of rolling batch portion of handler
"""
@@ -51,7 +51,7 @@ def translate_triton_params(self, parameters: dict) -> dict:
Helper function to convert DJL Serving parameter names to Triton
parameter names that TensorRT-LLM recognizes.
:param parameters (dict): Parameters pertaining to a specific request
:param parameters: Parameters pertaining to a specific request
:return: The same parameters dict, but with TensorRT-LLM style parameter names.
"""
@@ -79,10 +79,10 @@ def inference(self, input_data: list[str], parameters: list[dict]) -> list:
Loads new requests into the batch when there is availability, and gets output tokens from the backend
asynchronously.
:param input_data (list[str]): List of input prompts.
:param parameters (list[dict]): List of settings pertaining to each request.
:param input_data: List of input prompts.
:param parameters: List of settings pertaining to each request.
:return results (list): List of dictionaries, one for each request, that contain output tokens and other data.
:return results: List of dictionaries, one for each request, that contain output tokens and other data.
"""
batch_size = len(input_data)
# add pending requests to active requests list
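translate_triton_params() above is a pure renaming pass: the request's DJL Serving parameter names are swapped for the names the TensorRT-LLM Triton backend expects, and the same dict is returned. A hedged sketch of that pattern; the mapping entries below are hypothetical placeholders, not the real DJL-to-Triton table:

```python
# Hypothetical name mapping, for illustration only.
_NAME_MAP = {
    "max_new_tokens": "max_tokens",
    "seed": "random_seed",
}

def translate_params(parameters: dict) -> dict:
    """Rename known keys in place and leave everything else untouched."""
    for djl_name, backend_name in _NAME_MAP.items():
        if djl_name in parameters:
            parameters[backend_name] = parameters.pop(djl_name)
    return parameters

print(translate_params({"max_new_tokens": 128, "temperature": 0.7}))
# {'temperature': 0.7, 'max_tokens': 128}
```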
@@ -41,12 +41,12 @@ class VLLMRollingBatch(RollingBatch):
"""

# TODO: Make properties the only parameter, after refactoring all rolling batch handlers
def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
def __init__(self, model_id_or_path: str, properties: dict, **kwargs) -> None:
"""
Initializes the VLLMRollingBatch.
:param model_id_or_path (str): Currently unused since there is a copy inside properties
:param properties (dict): other properties of the model, such as decoder strategy
:param model_id_or_path: Currently unused since there is a copy inside properties
:param properties: other properties of the model, such as decoder strategy
"""
self.vllm_configs = VllmRbProperties(**properties)
super().__init__(waiting_steps=self.vllm_configs.waiting_steps,
@@ -71,7 +71,7 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
self.engine = LLMEngine.from_engine_args(args)
self.request_cache = OrderedDict()

def reset(self):
def reset(self) -> None:
"""
Aborts all requests
"""
@@ -80,12 +80,12 @@ def reset(self):
self.request_cache = OrderedDict()
super().reset()

def translate_vllm_params(self, parameters: dict):
def translate_vllm_params(self, parameters: dict) -> dict:
"""
Helper function to convert DJL Serving parameter names to parameter names
that VLLM recognizes.
:param parameters (dict): Parameters pertaining to a specific request
:param parameters: Parameters pertaining to a specific request
:return: The same parameters dict, but with VLLM style parameter names.
"""
@@ -104,10 +104,10 @@ def inference(self, input_data: list[str], parameters: list[dict]) -> list:
"""
Adds new requests and gets output tokens from the backend.
:param input_data (list[str]): List of input prompts.
:param parameters (list[dict]): List of settings pertaining to each request.
:param input_data: List of input prompts.
:param parameters: List of settings pertaining to each request.
:return results (list): List of dictionaries, one for each request, that contain output tokens and other data.
:return results: List of dictionaries, one for each request, that contain output tokens and other data.
"""
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
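VLLMRollingBatch above drives vLLM's LLMEngine: each inference() call adds any new prompts, advances the engine, and collects per-request outputs via the request cache. A hedged sketch of that add-then-step loop written directly against the vLLM API (the model id and sampling values are assumptions; the real handler derives everything from VllmRbProperties and streams incremental tokens rather than final text):

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

prompts = {"0": "What is Deep Java Library?", "1": "Write a haiku about GPUs."}
for request_id, prompt in prompts.items():
    engine.add_request(request_id, prompt, SamplingParams(max_tokens=32))

# The handler calls engine.step() once per inference() call; this toy loop
# simply drains the batch to completion.
while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.request_id, output.outputs[0].text)
```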
