Skip to content

Commit

Permalink
use huggingface standard generation for tnx (deepjavalibrary#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan authored and KexinFeng committed Aug 16, 2023
1 parent 868f58a commit 5ec164d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 46 deletions.
64 changes: 22 additions & 42 deletions engines/python/setup/djl_python/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def get_stream_generator(execution_engine: str):
## execution_engine passed to this function is not the same engine specified in serving.properties
## in djl-serving. For e.g Accelerate and neuronx use Python as the engine serving.properties
## The engine here refers to backend model parallel framework.
if execution_engine in {"DeepSpeed", "Accelerate"}:
if execution_engine in {
"DeepSpeed", "Accelerate", "transformers-neuronx"
}:
return StreamingUtils._hf_model_stream_generator
elif execution_engine == "transformers-neuronx":
return StreamingUtils._transformers_neuronx_stream_generator
else:
raise ValueError(
f"{execution_engine} engine is not supported for streaming")
Expand Down Expand Up @@ -55,6 +55,12 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
dtype=torch.long,
device=input_ids.device)
stop_generation = False
engine = None
if "engine" in kwargs.keys():
engine = kwargs["engine"]

if "transformers-neuronx" == engine:
model.reset_generation()

if generic_model_class == "CausalLM":
input_length = input_ids.shape[1]
Expand Down Expand Up @@ -83,10 +89,17 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):

if generic_model_class == "CausalLM":
attention_mask_curr = attention_mask[:, :curr_length]
outputs = model.forward(input_ids=input_ids,
attention_mask=attention_mask_curr,
past_key_values=past_key_values,
use_cache=True)
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask_curr,
"past_key_values": past_key_values,
"use_cache": True
}
if "transformers-neuronx" == engine:
model_inputs = model.prepare_inputs_for_generation(
**model_inputs)
model_inputs["return_dict"] = True
outputs = model.forward(**model_inputs)

if generic_model_class == "Seq2SeqLM":
outputs = model.forward(
Expand All @@ -108,7 +121,8 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):

all_decoder_input_ids = torch.cat(
[all_decoder_input_ids, token_ids], dim=1)
past_key_values = outputs.past_key_values
if engine is None:
past_key_values = outputs.past_key_values
new_tokens_count += 1

not_eos_token_ids = (token_ids != tokenizer.eos_token_id).view(
Expand Down Expand Up @@ -136,40 +150,6 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):

yield token_text

@staticmethod
@torch.inference_mode()
def _transformers_neuronx_stream_generator(model, tokenizer, inputs,
**kwargs):
sequence_length = kwargs.get("seq_length",
StreamingUtils.DEFAULT_MAX_NEW_TOKENS)
top_k = kwargs.get("top_k", 50)
tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True)
input_ids = tokenized_inputs["input_ids"]
model.reset()
eos_token_id = model.config.eos_token_id
# populate key/value caches according to the prompt text
_, start = input_ids.shape
position_ids = torch.arange(start, dtype=torch.int32)
next_token_scores = model(input_ids, position_ids)

tokens = [input_ids]
for cur_len in range(start, sequence_length):
# don't sample EOS
next_token_scores[:, eos_token_id] = -float('inf')

# Remove all tokens with a probability less than the last token of the top-k
topk_values, topk_indices = torch.topk(next_token_scores, top_k)
probs = torch.nn.functional.softmax(topk_values, dim=-1)
inputs_in_topk = torch.multinomial(probs,
num_samples=1,
replacement=True)
inputs = torch.gather(topk_indices, 1, inputs_in_topk)
tokens.append(inputs)
token_text = tokenizer.batch_decode(inputs)
position_ids = torch.as_tensor([cur_len], dtype=torch.int32)
next_token_scores = model(inputs, position_ids)
yield token_text

@staticmethod
def _has_met_stopping_criteria(not_eos_token_ids, current_token_count,
max_new_tokens):
Expand Down
6 changes: 2 additions & 4 deletions engines/python/setup/djl_python/transformers-neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import torch
import tempfile
import os
import logging
Expand Down Expand Up @@ -188,10 +187,9 @@ def infer(self, inputs):
if self.enable_streaming:
stream_generator = StreamingUtils.get_stream_generator(
"transformers-neuronx")
model_kwargs["seq_length"] = parameters.get("max_length", 128)
# TODO: switch to new HF model interface
model_kwargs["engine"] = "transformers-neuronx"
outputs.add_stream_content(
stream_generator(self.model.model, self.tokenizer,
stream_generator(self.model, self.tokenizer,
input_text, **model_kwargs))
return outputs

Expand Down

0 comments on commit 5ec164d

Please sign in to comment.