diff --git a/engines/python/setup/djl_python/tensorrt_llm_python.py b/engines/python/setup/djl_python/tensorrt_llm_python.py
index afca6d0c6..548199c1d 100644
--- a/engines/python/setup/djl_python/tensorrt_llm_python.py
+++ b/engines/python/setup/djl_python/tensorrt_llm_python.py
@@ -10,7 +10,7 @@
 from djl_python.encode_decode import encode
 from djl_python.inputs import Input
 from djl_python.outputs import Output
-from djl_python.utils import parse_input
+from djl_python.utils import parse_input_with_client_batch
 
 
 def _get_value_based_on_tensor(value, index=None):
@@ -84,7 +84,7 @@ def __init__(self):
         self.model = None
         self.trt_configs = None
         self.initialized = False
-        self.parse_input = parse_input
+        self.parse_input = parse_input_with_client_batch
 
     def initialize(self, properties: dict):
         self.trt_configs = TensorRtLlmProperties(**properties)
@@ -102,7 +102,7 @@ def inference(self, inputs: Input) -> Output:
         """
         outputs = Output()
 
-        input_data, input_size, parameters, errors, batch = self.parse_input(
+        input_data, input_size, parameters, errors, batch, is_client_side_batch = self.parse_input(
             inputs, None, self.trt_configs.output_formatter)
         if len(input_data) == 0:
             for i in range(len(batch)):
@@ -113,7 +113,7 @@ def inference(self, inputs: Input) -> Output:
         params = parameters[0]
         if params.get("details", False):
             return self._stream_inference(inputs, input_data, input_size,
-                                          params, batch)
+                                          params, batch, is_client_side_batch)
 
         detokenized_python_response = self.model.generate(input_data, **params)
         results = [{
@@ -122,8 +122,9 @@
         offset = 0
         for i, item in enumerate(batch):
             content_type, accept = _get_accept_and_content_type(item)
-            batch_item = results[offset] if input_size[i] == 1 else results[
-                offset:offset + input_size[i]]
+            batch_item = results[offset:offset +
+                                 input_size[i]] if is_client_side_batch[
+                                     i] else results[offset]
             encode(outputs,
                    batch_item,
                    accept,
@@ -159,8 +160,8 @@ def _get_config(self, properties):
 
     # TODO TrtLLM python backend: Change it once T5 bug is fixed.
     def _stream_inference(self, inputs: Input, input_data: list[str],
-                          input_size: list[int], parameters: dict,
-                          batch: list) -> Output:
+                          input_size: list[int], parameters: dict, batch: list,
+                          is_client_side_batch: list) -> Output:
         outputs = Output()
         detokenized_python_response = self.model.generate(
             input_data, **parameters)
@@ -171,8 +172,9 @@ def _stream_inference(self, inputs: Input, input_data: list[str],
         for i, item in enumerate(batch):
             item = batch[i]
             accept, content_type = _get_accept_and_content_type(item)
-            batch_item = results[offset] if input_size[i] == 1 else results[
-                offset:offset + input_size[i]]
+            batch_item = results[offset:offset +
+                                 input_size[i]] if is_client_side_batch[
+                                     i] else results[offset]
             encode(outputs,
                    batch_item,
                    accept,
diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py
index 075915f9c..22fe57617 100644
--- a/engines/python/setup/djl_python/utils.py
+++ b/engines/python/setup/djl_python/utils.py
@@ -4,12 +4,13 @@
 from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
 
 
-def parse_input(
-    inputs: Input, tokenizer, output_formatter
-) -> tuple[list[str], list[int], list[dict], dict, list]:
+def parse_input_with_client_batch(
+    inputs: Input, tokenizer, output_formatter
+) -> tuple[list[str], list[int], list[dict], dict, list, list]:
     """
     Preprocessing function that extracts information from Input objects.
 
+    :param output_formatter: output formatter for the request
     :param inputs :(Input) a batch of inputs, each corresponding to a new request
     :param tokenizer: the tokenizer used for inference
 
@@ -18,12 +19,15 @@ def parse_input(
     :return parameters (list[dict]): parameters pertaining to each request
     :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
     :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
+    :return is_client_side_batch (list): a list of booleans indicating whether each request is a client-side batch
     """
     input_data = []
     input_size = []
     parameters = []
     errors = {}
     batch = inputs.get_batches()
+    # only set for client-side (dynamic) batches
+    is_client_side_batch = [False for _ in range(len(batch))]
     for i, item in enumerate(batch):
         try:
             content_type = item.get_property("Content-Type")
@@ -43,6 +47,8 @@ def parse_input(
                 _param["stream"] = input_map.pop("stream", False)
             if not isinstance(_inputs, list):
                 _inputs = [_inputs]
+            else:
+                is_client_side_batch[i] = True
             input_data.extend(_inputs)
             input_size.append(len(_inputs))
 
@@ -58,4 +64,25 @@ def parse_input(
         for _ in range(input_size[i]):
             parameters.append(_param)
 
+    return input_data, input_size, parameters, errors, batch, is_client_side_batch
+
+
+def parse_input(
+    inputs: Input, tokenizer, output_formatter
+) -> tuple[list[str], list[int], list[dict], dict, list]:
+    """
+    Preprocessing function that extracts information from Input objects.
+
+    :param output_formatter: output formatter for the request
+    :param inputs :(Input) a batch of inputs, each corresponding to a new request
+    :param tokenizer: the tokenizer used for inference
+
+    :return input_data (list[str]): a list of strings, each string being the prompt in a new request
+    :return input_size (list[int]): a list of ints being the size of each new request
+    :return parameters (list[dict]): parameters pertaining to each request
+    :return errors (dict): a dictionary mapping int indices to corresponding error strings if any
+    :return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
+    """
+    input_data, input_size, parameters, errors, batch, _ = parse_input_with_client_batch(
+        inputs, tokenizer, output_formatter)
     return input_data, input_size, parameters, errors, batch
diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py
index a35447769..a1cfb41ac 100644
--- a/tests/integration/llm/client.py
+++ b/tests/integration/llm/client.py
@@ -618,7 +618,8 @@ def get_model_name():
     "flan-t5-xl": {
         "batch_size": [1, 4],
         "seq_length": [256],
-        "tokenizer": "google/flan-t5-xl"
+        "tokenizer": "google/flan-t5-xl",
+        "details": True
     }
 }
 
@@ -1161,6 +1162,8 @@ def test_handler(model, model_spec):
         if spec.get("adapters", []):
             req["adapters"] = spec.get("adapters")
         params = {"max_new_tokens": seq_length}
+        if spec.get("details", False):
+            params["details"] = True
         req["parameters"] = params
         logging.info(f"req {req}")
         res = send_json(req)
diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py
index 4cff4e2db..5a7c509c9 100644
--- a/tests/integration/llm/prepare.py
+++ b/tests/integration/llm/prepare.py
@@ -897,9 +897,7 @@
         "option.output_formatter": "jsonlines"
     },
     "flan-t5-xl": {
-        "option.model_id": "s3://djl-llm/flan-t5-xl/",
-        "option.rolling_batch": "disable",
-        "option.entryPoint": "djl_python.tensorrt_llm"
+        "option.model_id": "s3://djl-llm/flan-t5-xl/"
     }
 }
 
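For reviewers, a minimal standalone sketch of the behavior this patch introduces (not part of the patch; `demo_parse` and `demo_regroup` are hypothetical names, and plain dicts stand in for djl_python Input objects): a request whose "inputs" field is a JSON list is flagged as a client-side batch and always gets a list of results back, while a single-prompt request still gets a single result object.

# Illustrative sketch only -- mirrors the is_client_side_batch logic added in
# utils.py and tensorrt_llm_python.py, without the djl_python dependencies.
def demo_parse(requests: list[dict]):
    input_data, input_size, is_client_side_batch = [], [], []
    for req in requests:
        prompts = req.get("inputs")
        if isinstance(prompts, list):
            # client-side (dynamic) batch: one request carries several prompts
            is_client_side_batch.append(True)
        else:
            prompts = [prompts]
            is_client_side_batch.append(False)
        input_data.extend(prompts)
        input_size.append(len(prompts))
    return input_data, input_size, is_client_side_batch


def demo_regroup(results: list, input_size: list, is_client_side_batch: list):
    # A client-side batch gets a list back (even a list of one);
    # a plain request gets a single result object.
    regrouped, offset = [], 0
    for i, size in enumerate(input_size):
        if is_client_side_batch[i]:
            regrouped.append(results[offset:offset + size])
        else:
            regrouped.append(results[offset])
        offset += size
    return regrouped


if __name__ == "__main__":
    data, sizes, flags = demo_parse([{"inputs": ["a", "b"]}, {"inputs": "c"}])
    # data == ["a", "b", "c"], sizes == [2, 1], flags == [True, False]
    print(demo_regroup([f"gen({p})" for p in data], sizes, flags))
    # [['gen(a)', 'gen(b)'], 'gen(c)']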