This repository has been archived by the owner on May 28, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
optuna_transformers.py
181 lines (157 loc) · 7.01 KB
/
optuna_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (c) 2021 Timothy Wolff-Piggott
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""Integration of Optuna and Transformers."""
import logging
import os
from numbers import Number
from typing import Dict, Union
import mlflow
import transformers
from transformers import TrainerControl, TrainerState, TrainingArguments
from hpoflow.optuna_mlflow import OptunaMLflow
_logger = logging.getLogger(__name__)
class OptunaMLflowCallback(transformers.TrainerCallback):
    """Integration of Optuna and Transformers.

    Class based on :class:`transformers.TrainerCallback`; integrates with OptunaMLflow to send
    the logs to ``MLflow`` and ``Optuna`` during model training.
    """

    def __init__(
        self,
        trial: OptunaMLflow,
        log_training_args: bool = True,
        log_model_config: bool = True,
    ):
        """Constructor.

        Args:
            trial: The OptunaMLflow object.
            log_training_args: Whether to log all Transformers TrainingArguments as MLflow params.
            log_model_config: Whether to log the Transformers model config as MLflow params.
        """
        self._trial = trial
        self._log_training_args = log_training_args
        self._log_model_config = log_model_config
        # set to True by the first call to setup()
        self._initialized = False
        # switched on in setup() via the HF_MLFLOW_LOG_ARTIFACTS environment variable
        self._log_artifacts = False

    @staticmethod
    def _add_key_prefix(prefix: str, values: Dict[str, object]) -> Dict[str, object]:
        """Return a new dict with ``prefix`` prepended to every key of ``values``."""
        return {f"{prefix}{key}": value for key, value in values.items()}

    def setup(
        self,
        args: TrainingArguments,
        state: TrainerState,
        model: Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, None],
    ):
        """Setup the optional MLflow integration.

        You can set the environment variable ``HF_MLFLOW_LOG_ARTIFACTS``. It is to use
        :func:`mlflow.log_artifacts` to log artifacts. This only makes sense if logging to a remote
        server, e.g. s3 or GCS. If set to ``True``, ``1`` or ``Yes`` (case-insensitive), will copy
        whatever is in TrainerArgument's output_dir to the local or remote artifact storage.
        Using it without a remote storage will just copy the files to your artifact location.

        On the world process zero, the training arguments (keys prefixed with
        ``hf_train_arg_``) and the model config (keys prefixed with ``hf_model_cfg_``) are
        logged as params through the trial, if the respective option is enabled. Values whose
        string representation is longer than MLflow's limit are dropped with a warning.

        Args:
            args: The training arguments of the trainer.
            state: The current trainer state.
            model: The model being trained; may be ``None``.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1", "YES"}:
            self._log_artifacts = True  # False is default

        if state.is_world_process_zero:
            combined_dict: Dict[str, object] = {}

            if self._log_training_args:
                training_args = self._add_key_prefix("hf_train_arg_", args.to_dict())
                _logger.debug("Logging training arguments. training_args: %s", training_args)
                combined_dict.update(training_args)

            if (
                model is not None
                and self._log_model_config
                and hasattr(model, "config")
                and model.config is not None  # type: ignore
            ):
                model_config = self._add_key_prefix(
                    "hf_model_cfg_", model.config.to_dict()  # type: ignore
                )
                _logger.debug("Logging model config. model_config: %s", model_config)
                combined_dict.update(model_config)

            # TODO: call a DRY function in the mlflow module
            # remove params that are too long for MLflow
            max_param_val_length = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
            # iterate over a snapshot because entries are deleted inside the loop
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > max_param_val_length:
                    # report the limit actually enforced; it differs between MLflow versions
                    _logger.warning(
                        "Trainer is attempting to log a value of "
                        "'%s' for key '%s' as a parameter. "
                        "MLflow's log_param() only accepts values no longer than "
                        "%d characters so we dropped this attribute.",
                        value,
                        name,
                        max_param_val_length,
                    )
                    del combined_dict[name]

            # TODO: call a DRY function in the mlflow module
            # MLflow cannot log more than a fixed number of values in one go,
            # so we have to split it
            batch_size = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
            combined_dict_items = list(combined_dict.items())
            for start in range(0, len(combined_dict_items), batch_size):
                self._trial.log_params(dict(combined_dict_items[start : start + batch_size]))

        self._initialized = True

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, None] = None,
        **kwargs,
    ) -> None:
        """Event called at the beginning of training.

        Call setup if not yet initialized.
        """
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(  # type: ignore
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Dict[str, Number],
        model: Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, None] = None,
        **kwargs,
    ):
        """Event called after logging the last logs.

        Log all metrics from Transformers logs as MLflow metrics at the appropriate step.
        Call setup first if not yet initialized. Values that are neither ``int`` nor ``float``
        are dropped with a warning. Metrics are only sent on the world process zero.
        """
        if not self._initialized:
            self.setup(args, state, model)

        if state.is_world_process_zero:
            metrics_to_log: Dict[str, float] = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics_to_log[k] = v
                else:
                    _logger.warning(
                        "Trainer is attempting to log a value of "
                        "'%s' of type %s for key '%s' as a metric. "
                        "MLflow's log_metric() only accepts float and "
                        "int types so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self._trial.log_metrics(metrics_to_log, step=state.global_step)

    def on_train_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        """Event called at the end of training.

        Log the training output as MLflow artifacts if logging artifacts is enabled.
        Nothing happens if setup was never called or this is not the world process zero.
        """
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                _logger.info("Logging artifacts. This may take time.")
                mlflow.log_artifacts(args.output_dir)