Skip to content

Commit

Permalink
⚰️ Remove deprecated (huggingface#2485)
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Dec 15, 2024
1 parent f68d11f commit 33fb9ef
Show file tree
Hide file tree
Showing 7 changed files with 0 additions and 96 deletions.
63 changes: 0 additions & 63 deletions trl/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Union

import yaml
from transformers import HfArgumentParser
from transformers.hf_argparser import DataClass, DataClassType
from transformers.utils.deprecation import deprecate_kwarg


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,45 +58,6 @@ class ScriptArguments:
ignore_bias_buffers: bool = False


class YamlConfigParser:
""" """

def __init__(self) -> None:
warnings.warn(
"The `YamlConfigParser` class is deprecated and will be removed in version 0.14. "
"If you need to use this class, please copy the code to your own project.",
DeprecationWarning,
)

def parse_and_set_env(self, config_path: str) -> dict:
with open(config_path) as yaml_file:
config = yaml.safe_load(yaml_file)

if "env" in config:
env_vars = config.pop("env")
if isinstance(env_vars, dict):
for key, value in env_vars.items():
os.environ[key] = str(value)
else:
raise ValueError("`env` field should be a dict in the YAML file.")

return config

def to_string(self, config):
final_string = ""
for key, value in config.items():
if isinstance(value, (dict, list)):
if len(value) != 0:
value = str(value)
value = value.replace("'", '"')
value = f"'{value}'"
else:
continue

final_string += f"--{key} {value} "
return final_string


def init_zero_verbose():
"""
Perform zero verbose init - use this method on top of the CLI modules to make
Expand Down Expand Up @@ -165,16 +124,9 @@ class MyArguments:
```
"""

@deprecate_kwarg(
"ignore_extra_args",
"0.14.0",
warn_if_greater_or_equal_version=True,
additional_message="Use the `return_remaining_strings` in the `parse_args_and_config` method instead.",
)
def __init__(
self,
dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None,
ignore_extra_args: Optional[bool] = None,
**kwargs,
):
# Make sure dataclass_types is an iterable
Expand All @@ -192,18 +144,6 @@ def __init__(
)

super().__init__(dataclass_types=dataclass_types, **kwargs)
self._ignore_extra_args = ignore_extra_args

def post_process_dataclasses(self, dataclasses):
"""
Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments.
"""
warnings.warn(
"The `post_process_dataclasses` method is deprecated and will be removed in version 0.14. "
"It is no longer functional and can be safely removed from your code.",
DeprecationWarning,
)
return dataclasses

def parse_args_and_config(
self, args: Optional[Iterable[str]] = None, return_remaining_strings: bool = False
Expand All @@ -216,9 +156,6 @@ def parse_args_and_config(
default values in the dataclasses. Command line arguments can override values set by the config file. The
method also sets any environment variables specified in the `env` field of the config file.
"""
if self._ignore_extra_args is not None:
return_remaining_strings = not self._ignore_extra_args

args = list(args) if args is not None else sys.argv[1:]
if "--config" in args:
# Get the config file path from
Expand Down
4 changes: 0 additions & 4 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
Expand Down Expand Up @@ -106,9 +105,6 @@ class CPOTrainer(Trainer):

_tag_names = ["trl", "cpo"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
Expand Down
14 changes: 0 additions & 14 deletions trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional
Expand Down Expand Up @@ -172,7 +171,6 @@ class DPOConfig(TrainingArguments):
truncation_mode: str = "keep_end"
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_target_length: Optional[int] = None # deprecated in favor of max_completion_length
max_completion_length: Optional[int] = None
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
Expand All @@ -194,15 +192,3 @@ class DPOConfig(TrainingArguments):
rpo_alpha: Optional[float] = None
discopop_tau: float = 0.05
use_num_logits_to_keep: bool = False

def __post_init__(self):
if self.max_target_length is not None:
warnings.warn(
"The `max_target_length` argument is deprecated in favor of `max_completion_length` and will be "
"removed in v0.14.",
FutureWarning,
)
if self.max_completion_length is None:
self.max_completion_length = self.max_target_length

return super().__post_init__()
4 changes: 0 additions & 4 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
)
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..models import PreTrainedModelWrapper, create_reference_model
Expand Down Expand Up @@ -318,9 +317,6 @@ class KTOTrainer(Trainer):

_tag_names = ["trl", "kto"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
Expand Down
4 changes: 0 additions & 4 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..models import create_reference_model
Expand Down Expand Up @@ -128,9 +127,6 @@ class OnlineDPOTrainer(Trainer):

_tag_names = ["trl", "online-dpo"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[PreTrainedModel, nn.Module],
Expand Down
4 changes: 0 additions & 4 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils.deprecation import deprecate_kwarg

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
Expand Down Expand Up @@ -72,9 +71,6 @@
class RLOOTrainer(Trainer):
_tag_names = ["trl", "rloo"]

@deprecate_kwarg(
"tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
config: RLOOConfig,
Expand Down
3 changes: 0 additions & 3 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.utils.deprecation import deprecate_kwarg

from ..import_utils import is_unsloth_available
from ..trainer.model_config import ModelConfig
Expand Down Expand Up @@ -897,7 +896,6 @@ def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
return kwargs


@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True)
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
Expand Down Expand Up @@ -926,7 +924,6 @@ def get_kbit_device_map() -> Optional[dict[str, int]]:
return None


@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True)
def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
if model_args.use_peft is False:
return None
Expand Down

0 comments on commit 33fb9ef

Please sign in to comment.