From 46ca184f1948c201ef9aefd008e7c20084786288 Mon Sep 17 00:00:00 2001
From: Qing Lan
Date: Tue, 30 May 2023 13:05:33 -0700
Subject: [PATCH] use huggingface standard generation for tnx

---
 .../setup/djl_python/streaming_utils.py       | 64 +++++++------------
 .../setup/djl_python/transformers-neuronx.py  |  6 +-
 2 files changed, 24 insertions(+), 46 deletions(-)

diff --git a/engines/python/setup/djl_python/streaming_utils.py b/engines/python/setup/djl_python/streaming_utils.py
index 109cb715a..e0a977285 100644
--- a/engines/python/setup/djl_python/streaming_utils.py
+++ b/engines/python/setup/djl_python/streaming_utils.py
@@ -23,10 +23,10 @@ def get_stream_generator(execution_engine: str):
         ## execution_engine passed to this function is not the same engine specified in serving.properties
         ## in djl-serving. For e.g Accelerate and neuronx use Python as the engine serving.properties
         ## The engine here refers to backend model parallel framework.
-        if execution_engine in {"DeepSpeed", "Accelerate"}:
+        if execution_engine in {
+                "DeepSpeed", "Accelerate", "transformers-neuronx"
+        }:
             return StreamingUtils._hf_model_stream_generator
-        elif execution_engine == "transformers-neuronx":
-            return StreamingUtils._transformers_neuronx_stream_generator
         else:
             raise ValueError(
                 f"{execution_engine} engine is not supported for streaming")
@@ -55,6 +55,12 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
                                      dtype=torch.long,
                                      device=input_ids.device)
         stop_generation = False
+        engine = None
+        if "engine" in kwargs.keys():
+            engine = kwargs["engine"]
+
+        if "transformers-neuronx" == engine:
+            model.reset_generation()
 
         if generic_model_class == "CausalLM":
             input_length = input_ids.shape[1]
@@ -83,10 +89,17 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
 
             if generic_model_class == "CausalLM":
                 attention_mask_curr = attention_mask[:, :curr_length]
-                outputs = model.forward(input_ids=input_ids,
-                                        attention_mask=attention_mask_curr,
-                                        past_key_values=past_key_values,
-                                        use_cache=True)
+                model_inputs = {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask_curr,
+                    "past_key_values": past_key_values,
+                    "use_cache": True
+                }
+                if "transformers-neuronx" == engine:
+                    model_inputs = model.prepare_inputs_for_generation(
+                        **model_inputs)
+                    model_inputs["return_dict"] = True
+                outputs = model.forward(**model_inputs)
 
             if generic_model_class == "Seq2SeqLM":
                 outputs = model.forward(
@@ -108,7 +121,8 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
                 all_decoder_input_ids = torch.cat(
                     [all_decoder_input_ids, token_ids], dim=1)
 
-            past_key_values = outputs.past_key_values
+            if engine is None:
+                past_key_values = outputs.past_key_values
             new_tokens_count += 1
 
             not_eos_token_ids = (token_ids != tokenizer.eos_token_id).view(
@@ -136,40 +150,6 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
 
             yield token_text
 
-    @staticmethod
-    @torch.inference_mode()
-    def _transformers_neuronx_stream_generator(model, tokenizer, inputs,
-                                               **kwargs):
-        sequence_length = kwargs.get("seq_length",
-                                     StreamingUtils.DEFAULT_MAX_NEW_TOKENS)
-        top_k = kwargs.get("top_k", 50)
-        tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True)
-        input_ids = tokenized_inputs["input_ids"]
-        model.reset()
-        eos_token_id = model.config.eos_token_id
-        # populate key/value caches according to the prompt text
-        _, start = input_ids.shape
-        position_ids = torch.arange(start, dtype=torch.int32)
-        next_token_scores = model(input_ids, position_ids)
-
-        tokens = [input_ids]
-        for cur_len in range(start, sequence_length):
-            # don't sample EOS
-            next_token_scores[:, eos_token_id] = -float('inf')
-
-            # Remove all tokens with a probability less than the last token of the top-k
-            topk_values, topk_indices = torch.topk(next_token_scores, top_k)
-            probs = torch.nn.functional.softmax(topk_values, dim=-1)
-            inputs_in_topk = torch.multinomial(probs,
-                                               num_samples=1,
-                                               replacement=True)
-            inputs = torch.gather(topk_indices, 1, inputs_in_topk)
-            tokens.append(inputs)
-            token_text = tokenizer.batch_decode(inputs)
-            position_ids = torch.as_tensor([cur_len], dtype=torch.int32)
-            next_token_scores = model(inputs, position_ids)
-            yield token_text
-
     @staticmethod
     def _has_met_stopping_criteria(not_eos_token_ids, current_token_count,
                                    max_new_tokens):
diff --git a/engines/python/setup/djl_python/transformers-neuronx.py b/engines/python/setup/djl_python/transformers-neuronx.py
index 2b7737e5c..5cf8a054b 100644
--- a/engines/python/setup/djl_python/transformers-neuronx.py
+++ b/engines/python/setup/djl_python/transformers-neuronx.py
@@ -10,7 +10,6 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-import torch
 import tempfile
 import os
 import logging
@@ -188,10 +187,9 @@ def infer(self, inputs):
         if self.enable_streaming:
             stream_generator = StreamingUtils.get_stream_generator(
                 "transformers-neuronx")
-            model_kwargs["seq_length"] = parameters.get("max_length", 128)
-            # TODO: switch to new HF model interface
+            model_kwargs["engine"] = "transformers-neuronx"
             outputs.add_stream_content(
-                stream_generator(self.model.model, self.tokenizer,
+                stream_generator(self.model, self.tokenizer,
                                  input_text, **model_kwargs))
             return outputs
 
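
For reference, below is a minimal sketch of the Hugging Face standard generation pattern this patch standardizes on (prepare_inputs_for_generation feeding a cached forward pass), which replaces the hand-rolled neuronx sampling loop deleted above. It uses stock `transformers`; "gpt2", the prompt, the 8-token budget, and the greedy argmax pick are illustrative assumptions for the demo, not part of the patch, and the legacy tuple-style past_key_values assumes a 2023-era transformers release.

# Sketch only; assumptions: stock `transformers` on CPU, "gpt2" standing in
# for a transformers-neuronx model that exposes the same HF interface.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids)
past_key_values = None

with torch.inference_mode():
    for _ in range(8):
        # prepare_inputs_for_generation is the HF-standard hook the patch
        # routes through: once a KV cache exists it slices input_ids down to
        # the newest token and derives position_ids, so the streaming loop
        # needs no backend-specific cache bookkeeping.
        model_inputs = model.prepare_inputs_for_generation(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True)
        outputs = model(**model_inputs, return_dict=True)
        # Greedy pick of the next token (the real generator samples/streams).
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(next_token)], dim=1)
        past_key_values = outputs.past_key_values
        print(tokenizer.decode(next_token[0]), end="", flush=True)

Note that for the transformers-neuronx path the patch keeps past_key_values out of this loop entirely (the `if engine is None:` guard), since the neuronx model manages its cache internally and is instead reset up front via model.reset_generation().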