Commit cfeb709: PR changes
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Abhishek-TAMU committed Aug 6, 2024
1 parent 024b12e commit cfeb709
Showing 14 changed files with 123 additions and 138 deletions.
4 changes: 0 additions & 4 deletions build/accelerate_launch.py
@@ -88,10 +88,6 @@ def main():
log_level = log_level.upper()
logging.basicConfig(level=log_level)

# Configure for Image to get log level as the code after
# launch_command only takes log level from env var LOG_LEVEL and not CLI
# os.environ["LOG_LEVEL"] = log_level

args = process_accelerate_launch_args(job_config)
logging.debug("accelerate launch parsed args: %s", args)
except FileNotFoundError as e:
119 changes: 60 additions & 59 deletions tests/test_sft_trainer.py
@@ -16,6 +16,7 @@
"""

# Standard
from unittest import mock
import copy
import json
import logging
@@ -44,6 +45,7 @@
# Local
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.utils.logging import set_log_level

MODEL_ARGS = configs.ModelArguments(
model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
Expand Down Expand Up @@ -78,6 +80,64 @@
PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)


@mock.patch.dict(os.environ, {}, clear=True)
def test_set_log_level_for_logger_default():
"""
Ensure that the correct log level is being set for python native logger and
transformers logger when no env var or CLI flag is passed
"""

train_args = copy.deepcopy(TRAIN_ARGS)
training_args, logger = set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.WARNING
assert training_args.log_level == "passive"


@mock.patch.dict(os.environ, {}, clear=True)
def test_set_log_level_for_logger_with_env_var():
"""
Ensure that the correct log level is being set for python native logger and
transformers logger when env var LOG_LEVEL is used
"""

train_args_env = copy.deepcopy(TRAIN_ARGS)
os.environ["LOG_LEVEL"] = "info"
training_args, logger = set_log_level(train_args_env)
assert logger.getEffectiveLevel() == logging.INFO
assert training_args.log_level == "info"


@mock.patch.dict(os.environ, {}, clear=True)
def test_set_log_level_for_logger_with_set_verbosity_and_cli():
"""
Ensure that the correct log level is being set for python native logger and
log_level of transformers logger is unchanged when env var TRANSFORMERS_VERBOSITY is used
and CLI flag is passed
"""

train_args = copy.deepcopy(TRAIN_ARGS)
os.environ["TRANSFORMERS_VERBOSITY"] = "info"
train_args.log_level = "error"
training_args, logger = set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.ERROR
assert training_args.log_level == "error"


@mock.patch.dict(os.environ, {}, clear=True)
def test_set_log_level_for_logger_with_env_var_and_cli():
"""
Ensure that the correct log level is being set for python native logger and
transformers logger when env var LOG_LEVEL is used and CLI flag is passed
"""

train_args = copy.deepcopy(TRAIN_ARGS)
os.environ["LOG_LEVEL"] = "info"
train_args.log_level = "error"
training_args, logger = set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.ERROR
assert training_args.log_level == "error"


def test_run_train_requires_output_dir():
"""Check fails when output dir not provided."""
updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)
@@ -809,62 +869,3 @@ def test_pretokenized_dataset_wrong_format():
# is essentially swallowing a KeyError here.
with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)


def test_set_log_level_for_logger_default():
"""
Ensure that the correct log level is being set
for python native logger and transformers logger
"""

# Set env var TRANSFORMERS_VERBOSITY as None and test
os.unsetenv("TRANSFORMERS_VERBOSITY")
os.unsetenv("LOG_LEVEL")
os.unsetenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")
os.unsetenv("SFT_TRAINER_CONFIG_JSON_PATH")
train_args = copy.deepcopy(TRAIN_ARGS)

# TEST IF NO ENV VAR ARE SET AND NO CLI ARGUMENT IS PASSED
training_args, logger = sft_trainer.set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.WARNING
assert training_args.log_level in ["passive", "warning"]


def test_set_log_level_for_logger_with_env_var():
"""
Ensure that the correct log level is being set
for python native logger and transformers logger
"""

os.unsetenv("TRANSFORMERS_VERBOSITY")
os.unsetenv("LOG_LEVEL")
os.unsetenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")
os.unsetenv("SFT_TRAINER_CONFIG_JSON_PATH")
train_args_env = copy.deepcopy(TRAIN_ARGS)

# TEST IF LOG_LEVEL ENV VAR IS SET AND NO CLI ARGUMENT IS PASSED
os.environ["LOG_LEVEL"] = "info"
train_args_env.log_level = "passive" # Default
training_args, logger = sft_trainer.set_log_level(train_args_env)
assert logger.getEffectiveLevel() == logging.INFO
assert training_args.log_level == "info"


def test_set_log_level_for_logger_with_env_var_and_cli():
"""
Ensure that the correct log level is being set
for python native logger and transformers logger
"""

os.unsetenv("TRANSFORMERS_VERBOSITY")
os.unsetenv("LOG_LEVEL")
os.unsetenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")
os.unsetenv("SFT_TRAINER_CONFIG_JSON_PATH")
train_args = copy.deepcopy(TRAIN_ARGS)

# TEST IF LOG_LEVEL ENV VAR IS SET AND --log_level CLI ARGUMENT IS PASSED
os.environ["LOG_LEVEL"] = "info"
train_args.log_level = "error"
training_args, logger = sft_trainer.set_log_level(train_args)
assert logger.getEffectiveLevel() == logging.ERROR
assert training_args.log_level == "error"
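
The tests removed here reset the environment with os.unsetenv, which does not update os.environ for the running process, so values could leak between tests; the relocated tests at the top of the file isolate the environment with mock.patch.dict instead. A minimal illustration of the difference (not part of this commit):

# Illustration only; not part of the diff.
import os
from unittest import mock

os.environ["LOG_LEVEL"] = "info"
os.unsetenv("LOG_LEVEL")               # clears the C-level environment only
print("LOG_LEVEL" in os.environ)       # True: os.environ is not updated

with mock.patch.dict(os.environ, {}, clear=True):
    print("LOG_LEVEL" in os.environ)   # False: fully isolated inside the block
print("LOG_LEVEL" in os.environ)       # True: restored after the block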
2 changes: 1 addition & 1 deletion tuning/config/configs.py
@@ -140,7 +140,7 @@ class TrainingArguments(transformers.TrainingArguments):
default="passive",
metadata={
"help": "The log level to adopt during training. \
'passive' level which doesn't set anything and keeps the \
By default, 'passive' level is set which keeps the \
current log level for the Transformers library (which will be 'warning` by default) \
Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
},
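
A hedged illustration of what the reworded help text describes: leaving log_level at the "passive" default leaves the transformers verbosity at its own default (warning), while an explicit value overrides it. The snippet below is illustrative only and assumes TRANSFORMERS_VERBOSITY is not set:

# Illustration only; not part of the diff.
import logging
import transformers

# log_level="passive": nothing is changed, transformers keeps its default level.
assert transformers.logging.get_verbosity() == logging.WARNING

# An explicit value such as log_level="info" corresponds to:
transformers.logging.set_verbosity_info()
assert transformers.logging.get_verbosity() == logging.INFO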
39 changes: 2 additions & 37 deletions tuning/sft_trainer.py
@@ -16,8 +16,6 @@
from typing import Dict, List, Optional, Union
import dataclasses
import json
import logging
import os
import sys
import time
import traceback
@@ -62,6 +60,7 @@
USER_ERROR_EXIT_CODE,
write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.preprocessing_utils import (
format_dataset,
get_data_collator,
@@ -479,40 +478,6 @@ def parse_arguments(parser, json_config=None):
)


def set_log_level(parsed_training_args, logger_name=__name__):
"""Set log level of python native logger and TF logger via argument from CLI or env variable.
Args:
parsed_training_args
Training arguments for training model.
"""

# Clear any existing handlers if necessary
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

# Configure Python native logger and transformers log level
# If CLI arg is passed, assign same log level to python native logger
log_level = "WARNING"
if parsed_training_args.log_level != "passive":
log_level = parsed_training_args.log_level

# If CLI arg not is passed and env var LOG_LEVEL is set,
# assign same log level to both logger
elif os.environ.get("LOG_LEVEL"): # AND parsed_training_args.log_level == "passive"
log_level = os.environ.get("LOG_LEVEL")
parsed_training_args.log_level = (
log_level.lower()
if not os.environ.get("TRANSFORMERS_VERBOSITY")
else os.environ.get("TRANSFORMERS_VERBOSITY")
)

logging.basicConfig(level=log_level.upper())

train_logger = logging.getLogger(logger_name)
return parsed_training_args, train_logger


def main(**kwargs): # pylint: disable=unused-argument
parser = get_parser()
job_config = get_json_config()
@@ -532,7 +497,7 @@ def main(**kwargs): # pylint: disable=unused-argument
) = parse_arguments(parser, job_config)

# Function to set log level for python native logger and transformers training logger
training_args, logger = set_log_level(training_args)
training_args, logger = set_log_level(training_args, __name__)

logger.debug(
"Input args parsed: \
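
The set_log_level body removed above moves to tuning/utils/logging.py, which both sft_trainer.py and the tests now import. That file is not shown in this view; the following is a minimal sketch of what it plausibly contains, inferred from the removed body and the new tests, with names and exact structure treated as assumptions:

# tuning/utils/logging.py -- hypothetical sketch, not the file from this commit
# Standard
import logging
import os


def set_log_level(parsed_training_args, logger_name=__name__):
    """Set the log level of the Python native logger and the transformers
    logger from the --log_level CLI flag or the LOG_LEVEL env variable."""

    # Clear any existing handlers so basicConfig can reconfigure the root logger
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    log_level = "WARNING"
    if parsed_training_args.log_level != "passive":
        # CLI flag takes precedence: mirror it on the native logger.
        log_level = parsed_training_args.log_level
    elif os.environ.get("LOG_LEVEL"):
        # No CLI flag: fall back to LOG_LEVEL for both loggers, unless
        # TRANSFORMERS_VERBOSITY already controls the transformers level.
        log_level = os.environ["LOG_LEVEL"]
        parsed_training_args.log_level = (
            log_level.lower()
            if not os.environ.get("TRANSFORMERS_VERBOSITY")
            else os.environ["TRANSFORMERS_VERBOSITY"]
        )

    logging.basicConfig(level=log_level.upper())
    return parsed_training_args, logging.getLogger(logger_name)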
4 changes: 2 additions & 2 deletions tuning/trackers/aimstack_tracker.py
@@ -97,8 +97,8 @@ def __init__(self, tracker_config: AimConfig):
information about the repo or the server and port where aim db is present.
"""
super().__init__(name="aim", tracker_config=tracker_config)
# Configure log level
self.logger = logging.getLogger(__name__)
# Get logger with root log level
self.logger = logging.getLogger()

def get_hf_callback(self):
"""Returns the aim.hugging_face.AimCallback object associated with this tracker.
4 changes: 2 additions & 2 deletions tuning/trackers/filelogging_tracker.py
@@ -80,8 +80,8 @@ def __init__(self, tracker_config: FileLoggingTrackerConfig):
which contains the location of file where logs are recorded.
"""
super().__init__(name="file_logger", tracker_config=tracker_config)
# Configure log level
self.logger = logging.getLogger(__name__)
# Get logger with root log level
self.logger = logging.getLogger()

def get_hf_callback(self):
"""Returns the FileLoggingCallback object associated with this tracker.
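
The trackers now call logging.getLogger() with no name, which returns the root logger, so their messages follow whatever level set_log_level configured via logging.basicConfig; the module-level logging.info/logging.debug calls introduced in the files below route through the same root logger. A small illustration (not part of this commit):

# Illustration only; not part of the diff.
import logging

logging.basicConfig(level="INFO")      # what set_log_level does for level "info"
root_logger = logging.getLogger()      # what the trackers now hold in self.logger
root_logger.info("Registered file logging tracker")         # emitted at INFO
logging.debug("Arguments passed to control_action: %s", 0)  # suppressed at INFO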
14 changes: 5 additions & 9 deletions tuning/trackers/tracker_factory.py
@@ -23,10 +23,6 @@
from .filelogging_tracker import FileLoggingTracker
from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory

# Configure log level
logger = logging.getLogger(__name__)


# Information about all registered trackers
AIMSTACK_TRACKER = "aim"
FILE_LOGGING_TRACKER = "file_logger"
@@ -55,9 +51,9 @@ def _register_aim_tracker():
AimTracker = _get_tracker_class(AimStackTracker, AimConfig)

REGISTERED_TRACKERS[AIMSTACK_TRACKER] = AimTracker
logger.info("Registered aimstack tracker")
logging.info("Registered aimstack tracker")
else:
logger.info(
logging.info(
"Not registering Aimstack tracker due to unavailablity of package.\n"
"Please install aim if you intend to use it.\n"
"\t pip install aim"
@@ -73,14 +69,14 @@ def _is_tracker_installed(name):
def _register_file_logging_tracker():
FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig)
REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker
logger.info("Registered file logging tracker")
logging.info("Registered file logging tracker")


# List of Available Trackers
# file_logger - Logs loss to a file
# aim - Aimstack Tracker
def _register_trackers():
logger.info("Registering trackers")
logging.info("Registering trackers")
if AIMSTACK_TRACKER not in REGISTERED_TRACKERS:
_register_aim_tracker()
if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
@@ -143,7 +139,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
e = "Requested Tracker {} not found. List trackers available for use is - {} ".format(
name, available
)
logger.error(e)
logging.error(e)
raise ValueError(e)

meta = REGISTERED_TRACKERS[name]
4 changes: 0 additions & 4 deletions tuning/trainercontroller/controllermetrics/eval_metrics.py
@@ -17,14 +17,10 @@

# Standard
from typing import Any
import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

# Configure log level
logger = logging.getLogger(__name__)


class EvalMetrics(MetricHandler):
"""Implements the controller metric which exposes the evaluation metrics"""
Changes to a second controller-metrics file (file name not shown in this view)
@@ -18,7 +18,6 @@
# Standard
from collections import deque
from typing import Any
import logging

# Third Party
from transformers import TrainerState
@@ -27,7 +26,6 @@
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

# Configure log level
logger = logging.getLogger(__name__)
METRICS_KEY = "metrics"
LOG_LOSS_KEY = "loss"
TRAINING_LOSS_KEY = "training_loss"
5 changes: 1 addition & 4 deletions tuning/trainercontroller/operations/hfcontrols.py
@@ -10,9 +10,6 @@
# Local
from .operation import Operation

# Configure log level
logger = logging.getLogger(__name__)


class HFControls(Operation):
"""Implements the control actions for the HuggingFace controls in
@@ -40,7 +37,7 @@ def control_action(self, control: TrainerControl, **kwargs):
control: TrainerControl. Data class for controls.
kwargs: List of arguments (key, value)-pairs
"""
logger.debug("Arguments passed to control_action: %s", repr(kwargs))
logging.debug("Arguments passed to control_action: %s", repr(kwargs))
frame_info = inspect.currentframe().f_back
arg_values = inspect.getargvalues(frame_info)
setattr(control, arg_values.locals["action"], True)
7 changes: 2 additions & 5 deletions tuning/trainercontroller/patience.py
@@ -31,9 +31,6 @@
# will be exceeded after the fifth event.
MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure"

# Configure log level
logger = logging.getLogger(__name__)


class PatienceControl:
"""Implements the patience control for every rule"""
@@ -52,7 +49,7 @@ def should_tolerate(
elif self._mode == MODE_RESET_ON_FAILURE:
self._patience_counter = 0
if self._patience_counter <= self._patience_threshold:
logger.debug(
logging.debug(
"Control {} triggered on event {}: "
"Enforcing patience [patience_counter = {:.2f}, "
"patience_threshold = {:.2f}]".format(
@@ -63,7 +60,7 @@
)
)
return True
logger.debug(
logging.debug(
"Control {} triggered on event {}: "
"Exceeded patience [patience_counter = {:.2f}, "
"patience_threshold = {:.2f}]".format(
5 changes: 1 addition & 4 deletions tuning/utils/data_type_utils.py
@@ -19,9 +19,6 @@
# Third Party
import torch

# Configure log level
logger = logging.getLogger(__name__)


def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
"""Given a string representation of a Torch data type, convert it to the actual torch dtype.
@@ -36,7 +33,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
"""
dt = getattr(torch, dtype_str, None)
if not isinstance(dt, torch.dtype):
logger.error(" ValueError: Unrecognized data type of a torch.Tensor")
logging.error(" ValueError: Unrecognized data type of a torch.Tensor")
raise ValueError("Unrecognized data type of a torch.Tensor")
return dt

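
For context, a usage sketch of the helper touched here; the calls shown are illustrative assumptions, not part of the diff:

# Illustration only; not part of the diff.
import torch
from tuning.utils.data_type_utils import str_to_torch_dtype

assert str_to_torch_dtype("bfloat16") is torch.bfloat16
# An unrecognized name is logged via the root logger and raises ValueError:
str_to_torch_dtype("float99")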