
[TRTLLM Python backend] Fix the output format for client side batching in dynamic batch (#1718)
sindhuvahinis authored Apr 3, 2024
1 parent a0db65c commit 0644638
Showing 4 changed files with 47 additions and 17 deletions.
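For context, the two request shapes this commit distinguishes: with dynamic batching the server folds several independent HTTP requests into one engine batch, and any single request may itself carry a client-side batch (a list of prompts). A minimal illustration with made-up payloads that are not taken from this diff:

# Hypothetical request bodies (illustrative only).
single_request = {
    "inputs": "What is deep learning?",
    "parameters": {"max_new_tokens": 256},
}
client_side_batch = {
    "inputs": ["What is deep learning?", "What is TensorRT-LLM?"],  # list => client-side batch
    "parameters": {"max_new_tokens": 256},
}

# After this fix the response shape follows the request shape:
#   single_request    -> one result object, e.g. {"generated_text": "..."}
#   client_side_batch -> a list of results, one per prompt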
engines/python/setup/djl_python/tensorrt_llm_python.py (12 additions & 10 deletions)
@@ -10,7 +10,7 @@
 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
-from djl_python.utils import parse_input
+from djl_python.utils import parse_input_with_client_batch
 
 
 def _get_value_based_on_tensor(value, index=None):
@@ -84,7 +84,7 @@ def __init__(self):
         self.model = None
         self.trt_configs = None
         self.initialized = False
-        self.parse_input = parse_input
+        self.parse_input = parse_input_with_client_batch
 
     def initialize(self, properties: dict):
         self.trt_configs = TensorRtLlmProperties(**properties)
@@ -102,7 +102,7 @@ def inference(self, inputs: Input) -> Output:
         """
         outputs = Output()
 
-        input_data, input_size, parameters, errors, batch = self.parse_input(
+        input_data, input_size, parameters, errors, batch, is_client_side_batch = self.parse_input(
             inputs, None, self.trt_configs.output_formatter)
         if len(input_data) == 0:
             for i in range(len(batch)):
@@ -113,7 +113,7 @@
             params = parameters[0]
         if params.get("details", False):
             return self._stream_inference(inputs, input_data, input_size,
-                                          params, batch)
+                                          params, batch, is_client_side_batch)
 
         detokenized_python_response = self.model.generate(input_data, **params)
         results = [{
@@ -122,8 +122,9 @@
         offset = 0
         for i, item in enumerate(batch):
             content_type, accept = _get_accept_and_content_type(item)
-            batch_item = results[offset] if input_size[i] == 1 else results[
-                offset:offset + input_size[i]]
+            batch_item = results[offset:offset +
+                                 input_size[i]] if is_client_side_batch[
+                                     i] else results[offset]
             encode(outputs,
                    batch_item,
                    accept,
@@ -159,8 +160,8 @@ def _get_config(self, properties):
 
     # TODO TrtLLM python backend: Change it once T5 bug is fixed.
     def _stream_inference(self, inputs: Input, input_data: list[str],
-                          input_size: list[int], parameters: dict,
-                          batch: list) -> Output:
+                          input_size: list[int], parameters: dict, batch: list,
+                          is_client_side_batch: list) -> Output:
         outputs = Output()
         detokenized_python_response = self.model.generate(
             input_data, **parameters)
@@ -171,8 +172,9 @@ def _stream_inference(self, inputs: Input, input_data: list[str],
         for i, item in enumerate(batch):
             item = batch[i]
             accept, content_type = _get_accept_and_content_type(item)
-            batch_item = results[offset] if input_size[i] == 1 else results[
-                offset:offset + input_size[i]]
+            batch_item = results[offset:offset +
+                                 input_size[i]] if is_client_side_batch[
+                                     i] else results[offset]
             encode(outputs,
                    batch_item,
                    accept,
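A standalone sketch of the new slicing logic above, with made-up data: a request keeps the list-shaped response only when it was a client-side batch, which also covers a client-side batch that happens to contain a single prompt, where the old input_size[i] == 1 check would have unwrapped it.

# Illustrative data: three generated results spread across two requests.
results = [{"generated_text": "a"}, {"generated_text": "b"}, {"generated_text": "c"}]
input_size = [1, 2]                   # request 0 sent 1 prompt, request 1 sent 2
is_client_side_batch = [False, True]  # only request 1 was a client-side batch

offset = 0
for i in range(len(input_size)):
    if is_client_side_batch[i]:
        # keep the list shape: one entry per prompt in the client-side batch
        batch_item = results[offset:offset + input_size[i]]
    else:
        # plain request: unwrap to a single result object
        batch_item = results[offset]
    print(batch_item)
    offset += input_size[i]
# {'generated_text': 'a'}
# [{'generated_text': 'b'}, {'generated_text': 'c'}]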
engines/python/setup/djl_python/utils.py (30 additions & 3 deletions)
@@ -4,12 +4,13 @@
 from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
 
 
-def parse_input(
-        inputs: Input, tokenizer, output_formatter
-) -> tuple[list[str], list[int], list[dict], dict, list]:
+def parse_input_with_client_batch(
+        inputs: Input, tokenizer, output_formatter
+) -> tuple[list[str], list[int], list[dict], dict, list, list]:
     """
     Preprocessing function that extracts information from Input objects.
     :param output_formatter: output formatter for the request
     :param inputs :(Input) a batch of inputs, each corresponding to a new request
     :param tokenizer: the tokenizer used for inference
@@ -18,12 +19,15 @@ def parse_input(
     :return parameters (list[dict]): parameters pertaining to each request
     :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
     :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
+    :return is_client_size_batch (list): list of boolean value representing whether the input is a client side batch
     """
     input_data = []
     input_size = []
     parameters = []
     errors = {}
     batch = inputs.get_batches()
+    # only for dynamic batch
+    is_client_size_batch = [False for _ in range(len(batch))]
     for i, item in enumerate(batch):
         try:
             content_type = item.get_property("Content-Type")
@@ -43,6 +47,8 @@
                 _param["stream"] = input_map.pop("stream", False)
             if not isinstance(_inputs, list):
                 _inputs = [_inputs]
+            else:
+                is_client_size_batch[i] = True
             input_data.extend(_inputs)
             input_size.append(len(_inputs))
 
@@ -58,4 +64,25 @@
         for _ in range(input_size[i]):
             parameters.append(_param)
 
+    return input_data, input_size, parameters, errors, batch, is_client_size_batch
+
+
+def parse_input(
+        inputs: Input, tokenizer, output_formatter
+) -> tuple[list[str], list[int], list[dict], dict, list]:
+    """
+    Preprocessing function that extracts information from Input objects.
+    :param output_formatter: output formatter for the request
+    :param inputs :(Input) a batch of inputs, each corresponding to a new request
+    :param tokenizer: the tokenizer used for inference
+    :return input_data (list[str]): a list of strings, each string being the prompt in a new request
+    :return input_size (list[int]): a list of ints being the size of each new request
+    :return parameters (list[dict]): parameters pertaining to each request
+    :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
+    :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
+    """
+    input_data, input_size, parameters, errors, batch, _ = parse_input_with_client_batch(
+        inputs, tokenizer, output_formatter)
+    return input_data, input_size, parameters, errors, batch
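The flag itself comes from the shape of each request's "inputs" field. A minimal sketch of that detection, assuming each request body has already been decoded into a dict; the helper name and payloads below are illustrative, not from the diff.

# Sketch of the detection added above, over decoded request bodies.
def flag_client_side_batches(request_bodies: list[dict]) -> list[bool]:
    is_client_size_batch = [False for _ in range(len(request_bodies))]
    for i, input_map in enumerate(request_bodies):
        _inputs = input_map.get("inputs", "")
        if not isinstance(_inputs, list):
            _inputs = [_inputs]
        else:
            # a list-valued "inputs" marks this request as a client-side batch
            is_client_size_batch[i] = True
    return is_client_size_batch


print(flag_client_side_batches([{"inputs": "hi"}, {"inputs": ["a", "b"]}]))
# [False, True]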
tests/integration/llm/client.py (4 additions & 1 deletion)
@@ -618,7 +618,8 @@ def get_model_name():
     "flan-t5-xl": {
         "batch_size": [1, 4],
         "seq_length": [256],
-        "tokenizer": "google/flan-t5-xl"
+        "tokenizer": "google/flan-t5-xl",
+        "details": True
     }
 }
 
@@ -1161,6 +1162,8 @@ def test_handler(model, model_spec):
         if spec.get("adapters", []):
             req["adapters"] = spec.get("adapters")
         params = {"max_new_tokens": seq_length}
+        if spec.get("details", False):
+            params["details"] = True
         req["parameters"] = params
         logging.info(f"req {req}")
         res = send_json(req)
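The test change above forwards an optional "details" flag from the model spec into the request parameters. Roughly the payload the flan-t5-xl test would now send; the prompt text is made up and the request assembly is simplified.

# Simplified sketch of the request the integration test builds for flan-t5-xl
# after this change (the real test uses its own prompt helpers).
spec = {
    "batch_size": [1, 4],
    "seq_length": [256],
    "tokenizer": "google/flan-t5-xl",
    "details": True,
}
seq_length = 256
req = {"inputs": "translate English to German: How are you?"}
params = {"max_new_tokens": seq_length}
if spec.get("details", False):
    params["details"] = True
req["parameters"] = params
print(req)
# {'inputs': '...', 'parameters': {'max_new_tokens': 256, 'details': True}}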
tests/integration/llm/prepare.py (1 addition & 3 deletions)
@@ -897,9 +897,7 @@
         "option.output_formatter": "jsonlines"
     },
     "flan-t5-xl": {
-        "option.model_id": "s3://djl-llm/flan-t5-xl/",
-        "option.rolling_batch": "disable",
-        "option.entryPoint": "djl_python.tensorrt_llm"
+        "option.model_id": "s3://djl-llm/flan-t5-xl/"
     }
 }
