Skip to content

Commit

Permalink
add stream generation following huggingface
Browse files (browse the repository at this point in the history)
  • Loading branch information
Qing Lan committed May 30, 2023
1 parent d59fcb0 commit cbbc99d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion engines/python/setup/djl_python/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
if "engine" in kwargs.keys():
engine = kwargs["engine"]

+ if engine and "transformers-neuronx" == engine:
+ model.reset_generation()
+
if generic_model_class == "CausalLM":
input_length = input_ids.shape[1]
all_decoder_input_ids = tokenized_inputs["input_ids"]
Expand Down Expand Up @@ -93,9 +96,9 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
"use_cache": True
}
if engine and "transformers-neuronx" == engine:
- model_inputs["return_dict"] = True
model_inputs = model.prepare_inputs_for_generation(
**model_inputs)
+ model_inputs["return_dict"] = True
outputs = model.forward(**model_inputs)

if generic_model_class == "Seq2SeqLM":
Expand Down
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/transformers-neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def infer(self, inputs):
"transformers-neuronx")
model_kwargs["engine"] = "transformers-neuronx"
outputs.add_stream_content(
- stream_generator(self.model.model, self.tokenizer,
+ stream_generator(self.model, self.tokenizer,
input_text, **model_kwargs))
return outputs

Expand Down

0 comments on commit cbbc99d

Please sign in to comment.