Merge pull request #270 from Abhishek-TAMU/fix_logging
Fix: Removal of transformers logger and addition of python native logger
Abhishek-TAMU authored Aug 9, 2024
2 parents bb0caf9 + 0866bce commit ee25de4
Showing 14 changed files with 199 additions and 48 deletions.
16 changes: 12 additions & 4 deletions build/accelerate_launch.py
@@ -61,9 +61,6 @@ def get_base_model_from_adapter_config(adapter_config):


def main():
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)

if not os.getenv("TERMINATION_LOG_FILE"):
os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG

@@ -80,6 +77,18 @@ def main():
or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
)

# Configure log_level of python native logger.
# CLI arg takes precedence over env var. And if neither is set, we use default "WARNING"
log_level = job_config.get(
"log_level"
) # this will be set to either the value found or None
if (
not log_level
): # if log level not set by job_config aka by JSON, set it via env var or set default
log_level = os.environ.get("LOG_LEVEL", "WARNING")
log_level = log_level.upper()
logging.basicConfig(level=log_level)

args = process_accelerate_launch_args(job_config)
logging.debug("accelerate launch parsed args: %s", args)
except FileNotFoundError as e:
@@ -109,7 +118,6 @@ def main():
job_config["output_dir"] = tempdir
updated_args = serialize_args(job_config)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args

launch_command(args)
except subprocess.CalledProcessError as e:
# If the subprocess throws an exception, the base exception is hidden in the
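The hunk above gives the launcher a simple precedence for its log level: a log_level key in the JSON job config (or CLI) wins, then the LOG_LEVEL environment variable, then the default WARNING. A minimal standalone sketch of that precedence; the resolve_log_level helper and the example values are illustrative, not part of the PR:

import logging
import os

def resolve_log_level(job_config: dict) -> str:
    # Mirrors the launcher logic: JSON/CLI value first, then LOG_LEVEL, then "WARNING".
    log_level = job_config.get("log_level")  # None when the job config does not set it
    if not log_level:
        log_level = os.environ.get("LOG_LEVEL", "WARNING")
    return log_level.upper()

os.environ["LOG_LEVEL"] = "debug"
print(resolve_log_level({}))                      # DEBUG  (env var applies when the config is silent)
print(resolve_log_level({"log_level": "error"}))  # ERROR  (config/CLI value takes precedence)
logging.basicConfig(level=resolve_log_level({"log_level": "error"}))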
84 changes: 84 additions & 0 deletions tests/utils/test_logging.py
@@ -0,0 +1,84 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
from unittest import mock
import copy
import logging
import os

# First Party
from tests.test_sft_trainer import TRAIN_ARGS

# Local
from tuning.utils.logging import set_log_level


@mock.patch.dict(os.environ, {}, clear=True)
def test_set_log_level_for_logger_default():
"""
Ensure that the correct log level is being set for python native logger and
transformers logger when no env var or CLI flag is passed
"""

train_args = copy.deepcopy(TRAIN_ARGS)
training_args, logger = set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.WARNING
assert training_args.log_level == "passive"


@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True)
def test_set_log_level_for_logger_with_env_var():
"""
Ensure that the correct log level is being set for python native logger and
transformers logger when env var LOG_LEVEL is used
"""

train_args_env = copy.deepcopy(TRAIN_ARGS)
training_args, logger = set_log_level(train_args_env)
assert logger.getEffectiveLevel() == logging.INFO
assert training_args.log_level == "info"


@mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True)
def test_set_log_level_for_logger_with_set_verbosity_and_cli():
"""
Ensure that the correct log level is being set for python native logger and
log_level of transformers logger is unchanged when env var TRANSFORMERS_VERBOSITY is used
and CLI flag is passed
"""

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.log_level = "error"
training_args, logger = set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.ERROR
assert training_args.log_level == "error"


@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True)
def test_set_log_level_for_logger_with_env_var_and_cli():
"""
Ensure that the correct log level is being set for python native logger and
transformers logger when env var LOG_LEVEL is used and CLI flag is passed.
In this case, CLI arg takes precedence over the set env var LOG_LEVEL.
"""

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.log_level = "error"
training_args, logger = set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.ERROR
assert training_args.log_level == "error"
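The module under test, tuning/utils/logging.py, is not shown in this view of the diff. Working back from the assertions above, a plausible sketch of set_log_level is given below; the actual implementation in the PR may differ in details such as logger naming or how the default is applied:

# Standard
import logging
import os

def set_log_level(train_args, logger_name=None):
    # Sketch inferred from tests/utils/test_logging.py:
    # CLI/JSON --log_level wins, then the LOG_LEVEL env var, then WARNING.
    # "passive" on train_args.log_level means the user did not pick a level.
    log_level = "WARNING"
    if train_args.log_level and train_args.log_level != "passive":
        log_level = train_args.log_level
    elif os.environ.get("LOG_LEVEL"):
        log_level = os.environ["LOG_LEVEL"]
        train_args.log_level = log_level.lower()  # also forwarded to the transformers Trainer
    logger = logging.getLogger(logger_name) if logger_name else logging.getLogger()
    logger.setLevel(log_level.upper())
    return train_args, logger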
9 changes: 9 additions & 0 deletions tuning/config/configs.py
@@ -136,6 +136,15 @@ class TrainingArguments(transformers.TrainingArguments):
+ "Requires additional configs, see tuning.configs/tracker_configs.py"
},
)
log_level: str = field(
default="passive",
metadata={
"help": "The log level to adopt during training. \
By default, 'passive' level is set which keeps the \
current log level for the Transformers library (which will be 'warning' by default). \
Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
},
)


@dataclass
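Because the new field extends transformers.TrainingArguments, it is parsed through HfArgumentParser along with the rest of the CLI. A small sketch of how it behaves; the output_dir path and the values here are examples, not taken from the PR:

from transformers import HfArgumentParser

from tuning.config.configs import TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "/tmp/out", "--log_level", "debug"]
)
print(args.log_level)  # "debug"; without the flag it stays at the default "passive"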
12 changes: 7 additions & 5 deletions tuning/sft_trainer.py
@@ -33,7 +33,7 @@
LlamaTokenizerFast,
TrainerCallback,
)
from transformers.utils import is_accelerate_available, logging
from transformers.utils import is_accelerate_available
from trl import SFTConfig, SFTTrainer
import fire
import transformers
@@ -60,6 +60,7 @@
USER_ERROR_EXIT_CODE,
write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.preprocessing_utils import (
format_dataset,
get_data_collator,
@@ -111,7 +112,7 @@ def train(
fused_lora and fast_kernels must be used together (may change in future). \
"""

logger = logging.get_logger("sft_trainer")
train_args, logger = set_log_level(train_args, "sft_trainer_train")

# Validate parameters
if (not isinstance(train_args.num_train_epochs, (float, int))) or (
@@ -479,11 +480,8 @@ def parse_arguments(parser, json_config=None):


def main(**kwargs): # pylint: disable=unused-argument
logger = logging.get_logger("__main__")

parser = get_parser()
job_config = get_json_config()
logger.debug("Input args parsed: %s", job_config)
# accept arguments via command-line or JSON
try:
(
@@ -498,6 +496,10 @@ def main(**kwargs): # pylint: disable=unused-argument
fusedops_kernels_config,
exp_metadata,
) = parse_arguments(parser, job_config)

# Function to set log level for python native logger and transformers training logger
training_args, logger = set_log_level(training_args, __name__)

logger.debug(
"Input args parsed: \
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
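set_log_level returns the (possibly updated) train_args alongside the logger because transformers treats log_level == "passive" as "leave the library verbosity alone", while any explicit value makes the Trainer override it once training starts. A short illustration using the transformers logging utilities; the numeric value is the standard INFO level, not output captured from this PR:

from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_info()    # roughly what TRANSFORMERS_VERBOSITY=info does
print(hf_logging.get_verbosity())  # 20, i.e. INFO
# With TrainingArguments.log_level == "passive" the Trainer leaves this verbosity untouched;
# an explicit value such as "error" causes it to be overridden at train() time.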
5 changes: 3 additions & 2 deletions tuning/trackers/aimstack_tracker.py
@@ -14,11 +14,11 @@

# Standard
import json
import logging
import os

# Third Party
from aim.hugging_face import AimCallback # pylint: disable=import-error
from transformers.utils import logging

# Local
from .tracker import Tracker
@@ -99,7 +99,8 @@ def __init__(self, tracker_config: AimConfig):
information about the repo or the server and port where aim db is present.
"""
super().__init__(name="aim", tracker_config=tracker_config)
self.logger = logging.get_logger("aimstack_tracker")
# Get logger with root log level
self.logger = logging.getLogger()

def get_hf_callback(self):
"""Returns the aim.hugging_face.AimCallback object associated with this tracker.
5 changes: 3 additions & 2 deletions tuning/trackers/filelogging_tracker.py
@@ -15,11 +15,11 @@
# Standard
from datetime import datetime
import json
import logging
import os

# Third Party
from transformers import TrainerCallback
from transformers.utils import logging

# Local
from .tracker import Tracker
@@ -80,7 +80,8 @@ def __init__(self, tracker_config: FileLoggingTrackerConfig):
which contains the location of file where logs are recorded.
"""
super().__init__(name="file_logger", tracker_config=tracker_config)
self.logger = logging.get_logger("file_logging_tracker")
# Get logger with root log level
self.logger = logging.getLogger()

def get_hf_callback(self):
"""Returns the FileLoggingCallback object associated with this tracker.
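Both trackers now hold the root logger from logging.getLogger() instead of a transformers logger, so their verbosity simply follows whatever level the application configured (for example via logging.basicConfig in accelerate_launch.py or via set_log_level). A minimal sketch of that behaviour:

import logging

logging.basicConfig(level="INFO")          # done once by the application, as in the launcher
tracker_logger = logging.getLogger()       # what the trackers now store in self.logger
print(tracker_logger.getEffectiveLevel())  # 20 (INFO) -- the level just set on the root logger
tracker_logger.info("Registered file logging tracker")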
15 changes: 6 additions & 9 deletions tuning/trackers/tracker_factory.py
@@ -14,18 +14,15 @@

# Standard
import dataclasses
import logging

# Third Party
from transformers.utils import logging
from transformers.utils.import_utils import _is_package_available

# Local
from .filelogging_tracker import FileLoggingTracker
from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory

logger = logging.get_logger("tracker_factory")


# Information about all registered trackers
AIMSTACK_TRACKER = "aim"
FILE_LOGGING_TRACKER = "file_logger"
@@ -54,9 +51,9 @@ def _register_aim_tracker():
AimTracker = _get_tracker_class(AimStackTracker, AimConfig)

REGISTERED_TRACKERS[AIMSTACK_TRACKER] = AimTracker
logger.info("Registered aimstack tracker")
logging.info("Registered aimstack tracker")
else:
logger.info(
logging.info(
"Not registering Aimstack tracker due to unavailablity of package.\n"
"Please install aim if you intend to use it.\n"
"\t pip install aim"
@@ -72,14 +69,14 @@ def _is_tracker_installed(name):
def _register_file_logging_tracker():
FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig)
REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker
logger.info("Registered file logging tracker")
logging.info("Registered file logging tracker")


# List of Available Trackers
# file_logger - Logs loss to a file
# aim - Aimstack Tracker
def _register_trackers():
logger.info("Registering trackers")
logging.info("Registering trackers")
if AIMSTACK_TRACKER not in REGISTERED_TRACKERS:
_register_aim_tracker()
if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
@@ -142,7 +139,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
e = "Requested Tracker {} not found. List trackers available for use is - {} ".format(
name, available
)
logger.error(e)
logging.error(e)
raise ValueError(e)

meta = REGISTERED_TRACKERS[name]
5 changes: 0 additions & 5 deletions tuning/trainercontroller/controllermetrics/eval_metrics.py
@@ -18,14 +18,9 @@
# Standard
from typing import Any

# Third Party
from transformers.utils import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

logger = logging.get_logger(__name__)


class EvalMetrics(MetricHandler):
"""Implements the controller metric which exposes the evaluation metrics"""
(changed file, name not shown)
@@ -21,12 +21,10 @@

# Third Party
from transformers import TrainerState
from transformers.utils import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

logger = logging.get_logger(__name__)
METRICS_KEY = "metrics"
LOG_LOSS_KEY = "loss"
TRAINING_LOSS_KEY = "training_loss"
6 changes: 2 additions & 4 deletions tuning/trainercontroller/operations/hfcontrols.py
@@ -1,17 +1,15 @@
# Standard
from dataclasses import fields
import inspect
import logging
import re

# Third Party
from transformers import TrainerControl
from transformers.utils import logging

# Local
from .operation import Operation

logger = logging.get_logger(__name__)


class HFControls(Operation):
"""Implements the control actions for the HuggingFace controls in
@@ -39,7 +37,7 @@ def control_action(self, control: TrainerControl, **kwargs):
control: TrainerControl. Data class for controls.
kwargs: List of arguments (key, value)-pairs
"""
logger.debug("Arguments passed to control_action: %s", repr(kwargs))
logging.debug("Arguments passed to control_action: %s", repr(kwargs))
frame_info = inspect.currentframe().f_back
arg_values = inspect.getargvalues(frame_info)
setattr(control, arg_values.locals["action"], True)
10 changes: 4 additions & 6 deletions tuning/trainercontroller/patience.py
@@ -15,8 +15,8 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers.utils import logging
# Standard
import logging

# Resets the patience if the rule outcome happens to be false.
# Here, the expectation is to have unbroken "True"s for patience
@@ -31,8 +31,6 @@
# will be exceeded after the fifth event.
MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure"

logger = logging.get_logger(__name__)


class PatienceControl:
"""Implements the patience control for every rule"""
@@ -51,7 +49,7 @@ def should_tolerate(
elif self._mode == MODE_RESET_ON_FAILURE:
self._patience_counter = 0
if self._patience_counter <= self._patience_threshold:
logger.debug(
logging.debug(
"Control {} triggered on event {}: "
"Enforcing patience [patience_counter = {:.2f}, "
"patience_threshold = {:.2f}]".format(
@@ -62,7 +60,7 @@
)
)
return True
logger.debug(
logging.debug(
"Control {} triggered on event {}: "
"Exceeded patience [patience_counter = {:.2f}, "
"patience_threshold = {:.2f}]".format(
(diffs for the remaining changed files not loaded)
