@@ -1494,7 +1494,7 @@ def compute_transition_scores(
14941494 return transition_scores
14951495
14961496 def _validate_assistant (self , ** gen_mode_kwargs ):
1497- if assistant_model := gen_mode_kwargs .get ("assistant_model" ) is None :
1497+ if ( assistant_model := gen_mode_kwargs .get ("assistant_model" ) ) is None :
14981498 return
14991499 assistant_tokenizer = gen_mode_kwargs .get ("assistant_tokenizer" )
15001500
@@ -2157,7 +2157,7 @@ def _get_deprecated_gen_repo(
21572157 )
21582158 return repo
21592159
2160- def _get_mode_processor_kwargs (self , custom_generate , ** kwargs ) -> tuple [dict [str , Any ], dict [str , Any ]]:
2160+ def _get_mode_processor_kwargs (self , custom_generate , kwargs ) -> tuple [dict [str , Any ], dict [str , Any ]]:
21612161 """
21622162 Extracts and returns the generation mode and logit processor related keyword arguments from the provided kwargs.
21632163 """
@@ -2294,7 +2294,7 @@ def generate(
22942294 return custom_generate_function (model = self , ** generate_arguments )
22952295
22962296 # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
2297- gen_mode_kwargs , logits_processor_kwargs = self ._get_mode_processor_kwargs (custom_generate , ** kwargs )
2297+ gen_mode_kwargs , logits_processor_kwargs = self ._get_mode_processor_kwargs (custom_generate , kwargs )
22982298
22992299 generation_config , model_kwargs = self ._prepare_generation_config (
23002300 generation_config , use_model_defaults , ** kwargs
0 commit comments