Skip to content

Commit

Permalink
parse input method overload
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis committed Apr 2, 2024
1 parent 6308e63 commit 4685409
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def inference(self, inputs: Input) -> Output:
"""
outputs = Output()

input_data, input_size, parameters, errors, batch, _ = parse_input(
input_data, input_size, parameters, errors, batch = parse_input(
inputs, self.rolling_batch.get_tokenizer(),
self.trt_configs.output_formatter)
if len(input_data) == 0:
Expand Down
4 changes: 2 additions & 2 deletions engines/python/setup/djl_python/tensorrt_llm_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from djl_python.encode_decode import encode
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.utils import parse_input
from djl_python.utils import parse_input_with_client_batch


def _get_value_based_on_tensor(value, index=None):
Expand Down Expand Up @@ -101,7 +101,7 @@ def inference(self, inputs: Input) -> Output:
"""
outputs = Output()

input_data, input_size, parameters, errors, batch, is_client_side_batch = parse_input(
input_data, input_size, parameters, errors, batch, is_client_side_batch = parse_input_with_client_batch(
inputs, None, self.trt_configs.output_formatter)
if len(input_data) == 0:
for i in range(len(batch)):
Expand Down
23 changes: 22 additions & 1 deletion engines/python/setup/djl_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request


def parse_input(
def parse_input_with_client_batch(
inputs: Input, tokenizer, output_formatter
) -> tuple[list[str], list[int], list[dict], dict, list, list]:
"""
Expand Down Expand Up @@ -65,3 +65,24 @@ def parse_input(
parameters.append(_param)

return input_data, input_size, parameters, errors, batch, is_client_size_batch


def parse_input(
inputs: Input, tokenizer, output_formatter
) -> tuple[list[str], list[int], list[dict], dict, list]:
"""
Preprocessing function that extracts information from Input objects.
:param output_formatter: output formatter for the request
:param inputs :(Input) a batch of inputs, each corresponding to a new request
:param tokenizer: the tokenizer used for inference
:return input_data (list[str]): a list of strings, each string being the prompt in a new request
:return input_size (list[int]): a list of ints being the size of each new request
:return parameters (list[dict]): parameters pertaining to each request
:return errors (dict): a dictionary mapping int indices to corresponding error strings if any
:return batch (list): a list of Input objects contained in inputs (each one corresponds to a request)
"""
input_data, input_size, parameters, errors, batch, _ = parse_input_with_client_batch(
inputs, tokenizer, output_formatter)
return input_data, input_size, parameters, errors, batch

0 comments on commit 4685409

Please sign in to comment.