Skip to content

Commit 12ca97b

Browse files
committed
ops3
1 parent 8cac6c6 commit 12ca97b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/transformers/generation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,7 +1494,7 @@ def compute_transition_scores(
14941494
return transition_scores
14951495

14961496
def _validate_assistant(self, **gen_mode_kwargs):
1497-
if assistant_model := gen_mode_kwargs.get("assistant_model") is None:
1497+
if (assistant_model := gen_mode_kwargs.get("assistant_model")) is None:
14981498
return
14991499
assistant_tokenizer = gen_mode_kwargs.get("assistant_tokenizer")
15001500

@@ -2157,7 +2157,7 @@ def _get_deprecated_gen_repo(
21572157
)
21582158
return repo
21592159

2160-
def _get_mode_processor_kwargs(self, custom_generate, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
2160+
def _get_mode_processor_kwargs(self, custom_generate, kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
21612161
"""
21622162
Extracts and returns the generation mode and logit processor related keyword arguments from the provided kwargs.
21632163
"""
@@ -2294,7 +2294,7 @@ def generate(
22942294
return custom_generate_function(model=self, **generate_arguments)
22952295

22962296
# 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
2297-
gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs(custom_generate, **kwargs)
2297+
gen_mode_kwargs, logits_processor_kwargs = self._get_mode_processor_kwargs(custom_generate, kwargs)
22982298

22992299
generation_config, model_kwargs = self._prepare_generation_config(
23002300
generation_config, use_model_defaults, **kwargs

0 commit comments

Comments
 (0)