@@ -89,7 +89,6 @@ class GenerationConfig(PushToHubMixin):
8989 - *multinomial sampling* if `num_beams=1` and `do_sample=True`
9090 - *beam-search decoding* if `num_beams>1` and `do_sample=False`
9191 - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
92- - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
9392 - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
9493
9594 To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -202,18 +201,10 @@ class GenerationConfig(PushToHubMixin):
202201 bad_words_ids (`list[list[int]]`, *optional*):
203202 List of list of token ids that are not allowed to be generated. Check
204203 [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
205- force_words_ids (`list[list[int]]` or `list[list[list[int]]]`, *optional*):
206- List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple list of
207- words that must be included, the opposite to `bad_words_ids`. If given `list[list[list[int]]]`, this
208- triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
209- can allow different forms of each word.
210204 renormalize_logits (`bool`, *optional*, defaults to `False`):
211205 Whether to renormalize the logits after applying all the logits processors (including the custom
212206 ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
213207 are normalized but some logit processors break the normalization.
214- constraints (`list[Constraint]`, *optional*):
215- Custom constraints that can be added to the generation to ensure that the output will contain the use of
216- certain tokens as defined by `Constraint` objects, in the most sensible way possible.
217208 forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
218209 The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
219210 multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
@@ -374,9 +365,7 @@ def __init__(self, **kwargs):
374365 self .length_penalty = kwargs .pop ("length_penalty" , 1.0 )
375366 self .no_repeat_ngram_size = kwargs .pop ("no_repeat_ngram_size" , 0 )
376367 self .bad_words_ids = kwargs .pop ("bad_words_ids" , None )
377- self .force_words_ids = kwargs .pop ("force_words_ids" , None )
378368 self .renormalize_logits = kwargs .pop ("renormalize_logits" , False )
379- self .constraints = kwargs .pop ("constraints" , None )
380369 self .forced_bos_token_id = kwargs .pop ("forced_bos_token_id" , None )
381370 self .forced_eos_token_id = kwargs .pop ("forced_eos_token_id" , None )
382371 self .remove_invalid_values = kwargs .pop ("remove_invalid_values" , False )
@@ -434,6 +423,8 @@ def __init__(self, **kwargs):
434423 self .dola_layers = kwargs .pop ("dola_layers" , None )
435424 self .diversity_penalty = kwargs .pop ("diversity_penalty" , 0.0 )
436425 self .num_beam_groups = kwargs .pop ("num_beam_groups" , 1 )
426+ self .constraints = kwargs .pop ("constraints" , None )
427+ self .force_words_ids = kwargs .pop ("force_words_ids" , None )
437428
438429 # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
439430 # interface.
@@ -625,24 +616,6 @@ def validate(self, strict=False):
625616 minor_issues ["length_penalty" ] = single_beam_wrong_parameter_msg .format (
626617 flag_name = "length_penalty" , flag_value = self .length_penalty
627618 )
628- if self .constraints is not None :
629- minor_issues ["constraints" ] = single_beam_wrong_parameter_msg .format (
630- flag_name = "constraints" , flag_value = self .constraints
631- )
632-
633- # 2.3. detect incorrect parameterization specific to advanced beam modes
634- else :
635- # constrained beam search
636- if self .constraints is not None or self .force_words_ids is not None :
637- constrained_wrong_parameter_msg = (
638- "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
639- "However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
640- "mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
641- )
642- if self .do_sample is True :
643- raise ValueError (
644- constrained_wrong_parameter_msg .format (flag_name = "do_sample" , flag_value = self .do_sample )
645- )
646619
647620 # 2.4. check `num_return_sequences`
648621 if self .num_return_sequences != 1 :