Commit 567cffa

[fix] fix custom input and output formatting

siddvenk committed Apr 3, 2024 · 1 parent 157db6f
Showing 10 changed files with 22 additions and 14 deletions.
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/deepspeed.py
@@ -127,7 +127,8 @@ def initialize(self, properties: dict):
             "max_batch_size":
             int(properties.get("max_rolling_batch_size", 4)),
             "max_seq_len": int(properties.get("max_tokens", 1024)),
-            "tokenizer": self.tokenizer
+            "tokenizer": self.tokenizer,
+            "output_formatter": self.properties.output_formatter,
         }
         self.rolling_batch = DeepSpeedRollingBatch(self.model, properties,
                                                    **kwargs)
@@ -99,8 +99,6 @@ def construct_kwargs(cls, properties):
 
         # TODO remove this after refactor of all handlers
         if properties['rolling_batch'].value != RollingBatchEnum.disable.value:
-            if properties['output_formatter']:
-                kwargs["output_formatter"] = properties['output_formatter']
             if properties['waiting_steps']:
                 kwargs["waiting_steps"] = properties['waiting_steps']
 
@@ -46,7 +46,9 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
         :param properties (dict): other properties of the model, such as decoder strategy
         """
         self.lmi_dist_config = LmiDistRbProperties(**properties)
-        super().__init__(waiting_steps=self.lmi_dist_config.waiting_steps)
+        super().__init__(
+            waiting_steps=self.lmi_dist_config.waiting_steps,
+            output_formatter=self.lmi_dist_config.output_formatter)
         self.supports_speculative_decoding = supports_speculative_decoding()
         engine_kwargs = {}
         if self.supports_speculative_decoding:
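Note: the scheduler and vLLM rolling batches below repeat this same plumbing: each subclass forwards output_formatter from its parsed properties into the RollingBatch base class, which stores it as the server-wide default (see the rolling_batch.py hunk that adds default_output_formatter). A condensed sketch of the shared pattern, with abbreviated class names that are illustrative rather than the real djl_python modules:

from types import SimpleNamespace


class RollingBatch:
    """Minimal stand-in for djl_python's rolling batch base class."""

    def __init__(self, **kwargs):
        self.waiting_steps = kwargs.get("waiting_steps", None)
        # Server-wide default formatter; a request can still override it.
        self.default_output_formatter = kwargs.get("output_formatter", None)


class LmiDistRollingBatch(RollingBatch):
    def __init__(self, config):
        # Forward both knobs to the base class, as this commit does.
        super().__init__(waiting_steps=config.waiting_steps,
                         output_formatter=config.output_formatter)


config = SimpleNamespace(waiting_steps=3, output_formatter="jsonlines")
rb = LmiDistRollingBatch(config)
print(rb.default_output_formatter)  # -> "jsonlines"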
11 changes: 6 additions & 5 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -236,7 +236,6 @@ def get_output_formatter(output_formatter: Union[str, Callable], stream: bool):
     if output_formatter is not None:
         # TODO: Support custom loading of user supplied output formatter
         logging.warning(f"Unsupported output formatter: {output_formatter}")
-
     if stream:
         return _jsonlines_output_formatter
     return _json_output_formatter
@@ -275,7 +274,7 @@ def __init__(
         details: bool = False,
         input_ids: list = [],
         adapter=None,
-        output_formatter: Callable = None,
+        output_formatter: Union[str, Callable] = None,
     ):
         """
         Initialize a request
@@ -308,8 +307,6 @@ def __init__(
         self.step_token_number = 0
 
         # output formatter
-        output_formatter = output_formatter or parameters.pop(
-            "output_formatter", None)
         stream = parameters.pop("stream", False)
         self.output_formatter = get_output_formatter(output_formatter, stream)
 
@@ -431,6 +428,7 @@ def __init__(self, **kwargs):
         self.req_id_counter = 0
         self.waiting_steps = kwargs.get("waiting_steps", None)
         self.current_step = 0
+        self.default_output_formatter = kwargs.get("output_formatter", None)
 
     def reset(self):
         self.pending_requests = []
@@ -485,7 +483,10 @@ def get_new_requests(self,
                 details,
                 input_ids=self.get_tokenizer().encode(data)
                 if details else None,
-                adapter=adapter)
+                adapter=adapter,
+                output_formatter=params.pop(
+                    "output_formatter",
+                    self.default_output_formatter))
             self.pending_requests.append(request)
             self.req_id_counter += 1
             # wait steps and not feeding new requests
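Note: taken together, these hunks make the formatter resolution order explicit: a per-request "output_formatter" parameter wins, otherwise the server-wide default captured at construction time applies, and get_output_formatter falls back to JSON (or JSON Lines when streaming). A self-contained sketch of that resolution; the string-to-formatter mapping for "json"/"jsonlines" and the formatter bodies are assumptions, since this diff only shows the fallback branch:

import logging
from typing import Callable, Optional, Union


def _json_output_formatter(token):       # stand-in for the real formatter
    return token


def _jsonlines_output_formatter(token):  # stand-in for the real formatter
    return token


def get_output_formatter(output_formatter: Union[str, Callable, None],
                         stream: bool):
    if callable(output_formatter):
        return output_formatter
    if output_formatter == "json":
        return _json_output_formatter
    if output_formatter == "jsonlines":
        return _jsonlines_output_formatter
    if output_formatter is not None:
        # Custom loading of user-supplied formatters is still a TODO upstream.
        logging.warning(f"Unsupported output formatter: {output_formatter}")
    return _jsonlines_output_formatter if stream else _json_output_formatter


# Per-request resolution, mirroring get_new_requests: the request parameter
# overrides the server-wide default stored on the rolling batch.
default_output_formatter: Optional[str] = "json"  # e.g. from properties
params = {"max_new_tokens": 64}                   # request with no override
formatter = get_output_formatter(
    params.pop("output_formatter", default_output_formatter),
    stream=params.pop("stream", False))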
@@ -62,7 +62,9 @@ def __init__(self, model_id_or_path: str, properties: dict,
         """
 
         self.scheduler_configs = SchedulerRbProperties(**properties)
-        super().__init__(waiting_steps=self.scheduler_configs.waiting_steps)
+        super().__init__(
+            waiting_steps=self.scheduler_configs.waiting_steps,
+            output_formatter=self.scheduler_configs.output_formatter)
         self._init_model_and_tokenizer()
         self._init_scheduler()
 
@@ -42,7 +42,8 @@ def __init__(self, model_id_or_path: str, properties: dict,
         :param properties: other properties of the model, such as decoder strategy
         """
         self.vllm_configs = VllmRbProperties(**properties)
-        super().__init__(waiting_steps=self.vllm_configs.waiting_steps)
+        super().__init__(waiting_steps=self.vllm_configs.waiting_steps,
+                         output_formatter=self.vllm_configs.output_formatter)
         args = EngineArgs(
             model=self.vllm_configs.model_id_or_path,
             tensor_parallel_size=self.vllm_configs.tensor_parallel_degree,
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/tensorrt_llm.py
@@ -31,6 +31,7 @@ def __init__(self):
         self.initialized = False
         self.trt_configs = None
         self.rolling_batch = None
+        self.parse_input = parse_input
 
     def initialize(self, properties: dict):
         self.trt_configs = TensorRtLlmProperties(**properties)
@@ -50,7 +51,7 @@ def inference(self, inputs: Input) -> Output:
         """
         outputs = Output()
 
-        input_data, input_size, parameters, errors, batch = parse_input(
+        input_data, input_size, parameters, errors, batch = self.parse_input(
             inputs, self.rolling_batch.get_tokenizer(),
             self.trt_configs.output_formatter)
         if len(input_data) == 0:
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/tensorrt_llm_python.py
@@ -84,6 +84,7 @@ def __init__(self):
         self.model = None
         self.trt_configs = None
         self.initialized = False
+        self.parse_input = parse_input
 
     def initialize(self, properties: dict):
         self.trt_configs = TensorRtLlmProperties(**properties)
@@ -101,7 +102,7 @@ def inference(self, inputs: Input) -> Output:
         """
         outputs = Output()
 
-        input_data, input_size, parameters, errors, batch = parse_input(
+        input_data, input_size, parameters, errors, batch = self.parse_input(
             inputs, None, self.trt_configs.output_formatter)
         if len(input_data) == 0:
             for i in range(len(batch)):
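Note: both TensorRT-LLM handlers now route parse_input through an instance attribute instead of calling the module-level function directly. The commit message doesn't state the motivation, but the pattern makes the parser swappable per handler instance (for subclassing or tests) without patching the module. A minimal sketch; Handler and the parser body are illustrative, not the real djl_python signatures:

def parse_input(inputs):
    """Module-level default parser (stand-in for djl_python's parse_input)."""
    return inputs.get("inputs", "")


class Handler:
    def __init__(self):
        # Bind the default at construction; lookups go through the instance,
        # so assigning handler.parse_input later changes the behavior.
        self.parse_input = parse_input

    def inference(self, inputs):
        return self.parse_input(inputs)


handler = Handler()
print(handler.inference({"inputs": "hello"}))            # -> "hello"
handler.parse_input = lambda inputs: inputs["prompt"]    # instance override
print(handler.inference({"prompt": "hi"}))               # -> "hi"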
@@ -401,7 +401,6 @@ def test_hf_all_configs(self):
             "device_map": 'cpu',
             "load_in_8bit": True,
             "waiting_steps": 12,
-            "output_formatter": "jsonlines",
             "torch_dtype": torch.bfloat16
         })
 
2 changes: 2 additions & 0 deletions engines/python/setup/djl_python/transformers_neuronx.py
@@ -98,6 +98,8 @@ def set_tokenizer(self):
 
     def set_rolling_batch(self):
         if self.config.rolling_batch != "disable":
+            self.rolling_batch_config[
+                "output_formatter"] = self.config.output_formatter
             self.rolling_batch = NeuronRollingBatch(
                 self.model, self.tokenizer, self.config.batch_size,
                 self.config.n_positions, **self.rolling_batch_config)
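Note: the default formatter these handlers forward typically comes from the model's serving.properties, with a request-level "output_formatter" parameter overriding it per the rolling-batch change above. A hedged example, assuming the usual LMI option.* property names (the property file itself is not part of this diff):

engine=MPI
option.rolling_batch=lmi-dist
# Server-wide default; individual requests may still pass "output_formatter".
option.output_formatter=jsonlines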
