From 78045dedc8678af04f4e35ffe63f37be196a435b Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Mon, 8 Jul 2024 01:59:26 +0200 Subject: [PATCH] Fix `TRL_USE_RICH` environment variable handling (#1808) * Add `strtobool` custom implementation from `distutils` * Fix `TRL_USE_RICH` handling via `strtobool` * Run `make precommit` --- examples/scripts/dpo.py | 5 +++-- examples/scripts/sft.py | 5 +++-- examples/scripts/vsft_llava.py | 5 +++-- trl/commands/cli.py | 5 +++-- trl/env_utils.py | 34 ++++++++++++++++++++++++++++++++++ 5 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 trl/env_utils.py diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index baec049c81..0294932a49 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -55,9 +55,10 @@ import os from contextlib import nullcontext -TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) - from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser +from trl.env_utils import strtobool + +TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0")) if TRL_USE_RICH: init_zero_verbose() diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 38166264d6..1df011a4af 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -51,9 +51,10 @@ import os from contextlib import nullcontext -TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) - from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser +from trl.env_utils import strtobool + +TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0")) if TRL_USE_RICH: init_zero_verbose() diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py index 85cb98d5f3..32e9e0b804 100644 --- a/examples/scripts/vsft_llava.py +++ b/examples/scripts/vsft_llava.py @@ -68,9 +68,10 @@ import os from contextlib import nullcontext -TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) - from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser +from trl.env_utils import strtobool + +TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0")) if TRL_USE_RICH: init_zero_verbose() diff --git a/trl/commands/cli.py b/trl/commands/cli.py index 46b761473a..f5695c233d 100644 --- a/trl/commands/cli.py +++ b/trl/commands/cli.py @@ -41,8 +41,9 @@ def main(): trl_examples_dir = os.path.dirname(__file__) - # Force-use rich - os.environ["TRL_USE_RICH"] = "1" + # Force-use rich if the `TRL_USE_RICH` env var is not set + if "TRL_USE_RICH" not in os.environ: + os.environ["TRL_USE_RICH"] = "1" if command_name == "chat": command = f""" diff --git a/trl/env_utils.py b/trl/env_utils.py new file mode 100644 index 0000000000..64e98199e0 --- /dev/null +++ b/trl/env_utils.py @@ -0,0 +1,34 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Function `strtobool` copied and adapted from `distutils` (as deprected +# in Python 3.10). +# Reference: https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321 + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to True or False booleans. + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. + + Raises: + ValueError: if 'val' is anything else. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + if val in ("n", "no", "f", "false", "off", "0"): + return False + raise ValueError(f"Invalid truth value, it should be a string but {val} was provided instead.")