PR changes for changing logger
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Abhishek-TAMU committed Aug 5, 2024
1 parent da7acc6 commit 8af5792
Showing 12 changed files with 56 additions and 70 deletions.
26 changes: 14 additions & 12 deletions build/accelerate_launch.py
@@ -78,13 +78,20 @@ def main():
)

# Configure log_level of python native logger.
LOGLEVEL = None
if "log_level" in job_config and job_config["log_level"]:
LOGLEVEL = job_config["log_level"].upper()
logging.basicConfig(level=LOGLEVEL)
else:
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
log_level = job_config.get(
"log_level"
) # this will be set to either the value found or None
if (
not log_level
): # if the log level is not set by job_config (the JSON config), use the env var or the WARNING default
log_level = os.environ.get("LOG_LEVEL", "WARNING")
log_level = log_level.upper()
logging.basicConfig(level=log_level)

# Export the log level for the image, since the code that runs after
# launch_command only reads the log level from the env var LOG_LEVEL, not the CLI
os.environ["LOG_LEVEL"] = log_level

args = process_accelerate_launch_args(job_config)
logging.debug("accelerate launch parsed args: %s", args)
except FileNotFoundError as e:
@@ -114,11 +114,6 @@ def main():
job_config["output_dir"] = tempdir
updated_args = serialize_args(job_config)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args

# Configure for Image to get log level as the code after
# launch_command only takes log level from env var LOG_LEVEL and not CLI
os.environ["LOG_LEVEL"] = LOGLEVEL

launch_command(args)
except subprocess.CalledProcessError as e:
# If the subprocess throws an exception, the base exception is hidden in the
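
A minimal, self-contained sketch of the precedence this hunk implements (the resolve_log_level helper and the sample job_config are illustrative, not part of the repository): the JSON job config wins, then the LOG_LEVEL env var, then a WARNING default, and the resolved value is exported back to LOG_LEVEL so the code launched by launch_command, which only reads the env var, sees the same level.

import logging
import os

def resolve_log_level(job_config: dict) -> str:
    # Prefer the value from the JSON job config, if present and non-empty.
    log_level = job_config.get("log_level")
    if not log_level:
        # Otherwise fall back to the env var, then to the WARNING default.
        log_level = os.environ.get("LOG_LEVEL", "WARNING")
    log_level = log_level.upper()
    logging.basicConfig(level=log_level)
    # Re-export so the launched process picks up the same level from LOG_LEVEL.
    os.environ["LOG_LEVEL"] = log_level
    return log_level

print(resolve_log_level({"log_level": "debug"}))  # DEBUG
print(os.environ["LOG_LEVEL"])                    # DEBUG
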
4 changes: 3 additions & 1 deletion tuning/config/configs.py
@@ -140,7 +140,9 @@ class TrainingArguments(transformers.TrainingArguments):
default="passive",
metadata={
"help": "The log level to adopt during training. \
Possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
'passive' level which doesn't set anything and keeps the \
current log level for the Transformers library (which will be 'warning' by default). \
Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
},
)

68 changes: 34 additions & 34 deletions tuning/sft_trainer.py
@@ -113,6 +113,8 @@ def train(
fused_lora and fast_kernels must be used together (may change in future). \
"""

train_args, logger = set_log_level(train_args, "sft_trainer_train")

# Validate parameters
if (not isinstance(train_args.num_train_epochs, (float, int))) or (
train_args.num_train_epochs <= 0
@@ -232,9 +234,9 @@ def train(
)

max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
logging.info("Max sequence length is %s", max_seq_length)
logger.info("Max sequence length is %s", max_seq_length)
if train_args.max_seq_length > tokenizer.model_max_length:
logging.warning(
logger.warning(
"max_seq_length %s exceeds tokenizer.model_max_length \
%s, using tokenizer.model_max_length %s",
train_args.max_seq_length,
@@ -270,11 +272,11 @@ def train(

# Configure the collator and validate args related to packing prior to formatting the dataset
if train_args.packing:
logging.info("Packing is set to True")
logger.info("Packing is set to True")
data_collator = None
packing = True
else:
logging.info("Packing is set to False")
logger.info("Packing is set to False")
packing = False

# Validate if data args are set properly
@@ -339,7 +341,7 @@ def train(
tracker.track(metric=v, name=k, stage="additional_metrics")
tracker.set_params(params=exp_metadata, name="experiment_metadata")
except ValueError as e:
logging.error(
logger.error(
"Exception while saving additional metrics and metadata %s",
repr(e),
)
@@ -477,7 +479,7 @@ def parse_arguments(parser, json_config=None):
)


def set_log_level(parsed_training_args):
def set_log_level(parsed_training_args, logger_name=__name__):
"""Set log level of python native logger and TF logger via argument from CLI or env variable.
Args:
@@ -489,28 +491,25 @@ def set_log_level(parsed_training_args):
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

# Configure Python native logger log level
# Configure Python native logger and transformers log level
# If CLI arg is passed, assign same log level to python native logger
log_level = "WARNING"
if parsed_training_args.log_level != "passive":
logging.basicConfig(level=parsed_training_args.log_level.upper())
else:
# Assign value of either env var LOG_LEVEL or warning
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
train_logger = logging.getLogger()

# Check if env var TRANSFORMERS_VERBOSITY is not set.
# Else if env var is already set then, log level of transformers is automatically set.
if os.environ.get("TRANSFORMERS_VERBOSITY") is None:

# Check if "--log_level" CLI argument is not used (passive/warning is the default log level)
if parsed_training_args.log_level == "passive":
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
log_level = parsed_training_args.log_level

# If CLI arg is not passed and env var LOG_LEVEL is set,
# assign same log level to both loggers
elif os.environ.get("LOG_LEVEL"): # AND parsed_training_args.log_level == "passive"
log_level = os.environ.get("LOG_LEVEL")
parsed_training_args.log_level = (
log_level.lower()
if not os.environ.get("TRANSFORMERS_VERBOSITY")
else os.environ.get("TRANSFORMERS_VERBOSITY")
)

# Set log_level in TrainingArguments
parsed_training_args.log_level = LOGLEVEL.lower()
logging.basicConfig(level=log_level.upper())

train_logger = logging.getLogger()
train_logger = logging.getLogger(logger_name)
return parsed_training_args, train_logger
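
A hedged usage sketch of the reworked signature (it assumes the tuning package and its TrainingArguments are importable in your environment; the output_dir, env values, and logger name below are illustrative only). The optional logger_name argument lets each call site receive its own named logger, as train() and main() do above, while the level itself is resolved from --log_level first, then the LOG_LEVEL env var, then the WARNING default, with TRANSFORMERS_VERBOSITY taking over the transformers side when it is set.

import os

from tuning.config.configs import TrainingArguments
from tuning.sft_trainer import set_log_level

os.environ["LOG_LEVEL"] = "info"  # only honored because --log_level stays at its "passive" default

train_args = TrainingArguments(output_dir="/tmp/sft-out")
train_args, logger = set_log_level(train_args, "my_component")

logger.info("transformers log_level resolved to: %s", train_args.log_level)  # expected: info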


@@ -533,9 +532,10 @@ def main(**kwargs): # pylint: disable=unused-argument
) = parse_arguments(parser, job_config)

# Function to set log level for python native logger and transformers training logger
training_args, _ = set_log_level(training_args)
# training_args, logger = set_log_level(training_args, "sft_trainer_main")
training_args, logger = set_log_level(training_args)

logging.debug(
logger.debug(
"Input args parsed: \
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
tune_config %s, file_logger_config, %s aim_config %s, \
@@ -553,7 +553,7 @@ def main(**kwargs): # pylint: disable=unused-argument
exp_metadata,
)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())
write_termination_log(
f"Exception raised during training. This may be a problem with your input: {e}"
)
@@ -565,12 +565,12 @@ def main(**kwargs): # pylint: disable=unused-argument
try:
metadata = json.loads(exp_metadata)
if metadata is None or not isinstance(metadata, Dict):
logging.warning(
logger.warning(
"metadata cannot be converted to simple k:v dict ignoring"
)
metadata = None
except ValueError as e:
logging.error(
logger.error(
"failed while parsing extra metadata. pass a valid json %s", repr(e)
)

@@ -593,27 +593,27 @@ def main(**kwargs): # pylint: disable=unused-argument
fusedops_kernels_config=fusedops_kernels_config,
)
except (MemoryError, OutOfMemoryError) as e:
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())
write_termination_log(f"OOM error during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)
except FileNotFoundError as e:
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())
write_termination_log("Unable to load file: {}".format(e))
sys.exit(USER_ERROR_EXIT_CODE)
except HFValidationError as e:
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())
write_termination_log(
f"There may be a problem with loading the model. Exception: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)
except (TypeError, ValueError, EnvironmentError) as e:
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())
write_termination_log(
f"Exception raised during training. This may be a problem with your input: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
logger.error(traceback.format_exc())
write_termination_log(f"Unhandled exception during training: {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

4 changes: 1 addition & 3 deletions tuning/trackers/aimstack_tracker.py
@@ -98,9 +98,7 @@ def __init__(self, tracker_config: AimConfig):
"""
super().__init__(name="aim", tracker_config=tracker_config)
# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
self.logger = logging.getLogger("aimstack_tracker")
self.logger = logging.getLogger(__name__)

def get_hf_callback(self):
"""Returns the aim.hugging_face.AimCallback object associated with this tracker.
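
This file, and the ones below, drop their per-module basicConfig calls in favor of logging.getLogger(__name__): the level is configured once at start-up (by set_log_level or accelerate_launch) and every module logger inherits it from the root logger. A standalone sketch of that pattern (the logger name is illustrative, standing in for __name__):

import logging

# Done once at process start-up, e.g. by set_log_level() or accelerate_launch.
logging.basicConfig(level="DEBUG")

# Each module then simply asks for a named logger and inherits the root level.
logger = logging.getLogger("tuning.trackers.example_tracker")
logger.debug("visible because the root logger level is DEBUG")
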
4 changes: 1 addition & 3 deletions tuning/trackers/filelogging_tracker.py
@@ -81,9 +81,7 @@ def __init__(self, tracker_config: FileLoggingTrackerConfig):
"""
super().__init__(name="file_logger", tracker_config=tracker_config)
# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
self.logger = logging.getLogger("file_logging_tracker")
self.logger = logging.getLogger(__name__)

def get_hf_callback(self):
"""Returns the FileLoggingCallback object associated with this tracker.
4 changes: 1 addition & 3 deletions tuning/trackers/tracker_factory.py
@@ -25,9 +25,7 @@
from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger("tracker_factory")
logger = logging.getLogger(__name__)


# Information about all registered trackers
2 changes: 0 additions & 2 deletions tuning/trainercontroller/controllermetrics/eval_metrics.py
@@ -24,8 +24,6 @@
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger(__name__)


@@ -28,8 +28,6 @@
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger(__name__)
METRICS_KEY = "metrics"
LOG_LOSS_KEY = "loss"
2 changes: 0 additions & 2 deletions tuning/trainercontroller/operations/hfcontrols.py
@@ -12,8 +12,6 @@
from .operation import Operation

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger(__name__)


2 changes: 0 additions & 2 deletions tuning/trainercontroller/patience.py
@@ -33,8 +33,6 @@
MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure"

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger(__name__)


4 changes: 1 addition & 3 deletions tuning/utils/data_type_utils.py
@@ -21,9 +21,7 @@
import torch

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger("data_utils")
logger = logging.getLogger(__name__)


def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
4 changes: 1 addition & 3 deletions tuning/utils/preprocessing_utils.py
@@ -29,9 +29,7 @@
from tuning.utils.data_utils import apply_custom_formatting_template

# Configure log level
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)
logger = logging.getLogger("sft_trainer_preprocessing")
logger = logging.getLogger(__name__)

# In future we may make the fields configurable
JSON_INPUT_KEY = "input"
