From 9f9d997bd86e338695de6fc15fe07249e62e1065 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 25 Jan 2023 11:35:10 +0100 Subject: [PATCH] Support generation config in ORTModel (#651) * support generation config * add can_generate method in ORTModelForConditionalGeneration * trigger actions * fix typog * rollback --- optimum/onnxruntime/modeling_decoder.py | 29 +++++++++++++++++- optimum/onnxruntime/modeling_seq2seq.py | 40 +++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index f47204df48..72c01b2707 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -21,7 +21,7 @@ import numpy as np import torch -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, GenerationConfig from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions @@ -339,6 +339,7 @@ def __init__( use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, preprocessors: Optional[List] = None, + generation_config: Optional[GenerationConfig] = None, **kwargs ): """ @@ -357,6 +358,9 @@ def __init__( The directory under which the model exported to ONNX was saved. preprocessors (`Optional[List]`, defaults to `None`): The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. + generation_config (`Optional[GenerationConfig]`, defaults to `None`): + The generation configuration used by default when calling `generate()`. + Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate. """ # TODO: remove at version 2.0 def show_deprecated_argument(arg_name): @@ -399,6 +403,10 @@ def show_deprecated_argument(arg_name): self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) self.decoder_with_past_model_name = self.decoder_with_past_model_path.name + if generation_config is None: + generation_config = GenerationConfig.from_model_config(config) + self.generation_config = generation_config + @staticmethod def load_model( decoder_path: Union[str, Path], @@ -626,6 +634,20 @@ def _from_pretrained( if model_save_dir is None: model_save_dir = new_model_save_dir + generation_config = None + try: + generation_config = GenerationConfig.from_pretrained( + model_id, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + ) + except OSError: + logger.info("Generation config file not found, using a generation config created from the model config.") + return cls( model[0], config, @@ -633,6 +655,7 @@ def _from_pretrained( use_io_binding=use_io_binding, model_save_dir=model_save_dir, preprocessors=preprocessors, + generation_config=generation_config, ) @classmethod @@ -784,3 +807,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) for layer_past in past ) + + def can_generate(self): + """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" + return True diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 9f6e9a2dec..ac51f9817d 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -17,7 +17,6 @@ """ import logging -import re import shutil from abc import ABC, abstractmethod from pathlib import Path @@ -26,7 +25,7 @@ import numpy as np import torch -from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq +from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, GenerationConfig from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput @@ -728,6 +727,7 @@ def __init__( use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, preprocessors: Optional[List] = None, + generation_config: Optional[GenerationConfig] = None, **kwargs, ): """ @@ -748,6 +748,9 @@ def __init__( The directory under which the model exported to ONNX was saved. preprocessors (`Optional[List]`, defaults to `None`): The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. + generation_config (`Optional[GenerationConfig]`, defaults to `None`): + The generation configuration used by default when calling `generate()`. + Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate. """ # TODO: remove at version 2.0 def show_deprecated_argument(arg_name): @@ -804,6 +807,10 @@ def show_deprecated_argument(arg_name): self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) self.decoder_with_past_model_name = self.decoder_with_past_model_path.name + if generation_config is None: + generation_config = GenerationConfig.from_model_config(config) + self.generation_config = generation_config + @abstractmethod def _initialize_encoder( self, @@ -1076,6 +1083,20 @@ def _from_pretrained( if model_save_dir is None: model_save_dir = new_model_save_dir + generation_config = None + try: + generation_config = GenerationConfig.from_pretrained( + model_id, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + ) + except OSError: + logger.info("Generation config file not found, using a generation config created from the model config.") + return cls( *model[:2], config, @@ -1083,6 +1104,7 @@ def _from_pretrained( use_io_binding=use_io_binding, model_save_dir=model_save_dir, preprocessors=preprocessors, + generation_config=generation_config, ) @classmethod @@ -1178,6 +1200,12 @@ def to(self, device: Union[torch.device, str, int]): return self + def can_generate(self): + logger.warning( + "ORTModelForConditionalGeneration is an abstract class and is not meant to be used for generation. Please use ORTModelForSeq2SeqLM or ORTModelForSpeechSeq2Seq." + ) + return False + class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin): """ @@ -1286,6 +1314,10 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: ) return reordered_past + def can_generate(self): + """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" + return True + class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin): """ @@ -1398,3 +1430,7 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + def can_generate(self): + """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" + return True