Skip to content

Commit

Permalink
🤐 Fix deprecation warnings (#2392)
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Nov 26, 2024
1 parent 16fa13c commit 4f937c7
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 8 deletions.
4 changes: 3 additions & 1 deletion trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ class CPOTrainer(Trainer):

_tag_names = ["trl", "cpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ class DPOTrainer(Trainer):

_tag_names = ["trl", "dpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ class KTOTrainer(Trainer):

_tag_names = ["trl", "kto"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ class OnlineDPOTrainer(Trainer):

_tag_names = ["trl", "online-dpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module],
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ class ORPOTrainer(Trainer):

_tag_names = ["trl", "orpo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase")
class RewardTrainer(Trainer):
_tag_names = ["trl", "reward-trainer"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
class RLOOTrainer(Trainer):
_tag_names = ["trl", "rloo"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
config: RLOOConfig,
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ class SFTTrainer(Trainer):

_tag_names = ["trl", "sft"]

@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True)
@deprecate_kwarg(
"tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down

0 comments on commit 4f937c7

Please sign in to comment.