[New Features] Trainer add Wandb and Tensorboard #7863

Merged: 1 commit, Jan 31, 2024
260 changes: 260 additions & 0 deletions paddlenlp/trainer/integrations.py
@@ -18,6 +18,10 @@

import importlib
import json
Collaborator comment:

In the PR description, please describe the effect of these new features and include a screenshot.

Also remember to update the Chinese documentation accordingly: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/trainer.md

import numbers
import os
import tempfile
from pathlib import Path

from ..peft import LoRAModel, PrefixModelForCausalLM
from ..transformers import PretrainedModel
@@ -29,6 +33,16 @@
return importlib.util.find_spec("visualdl") is not None


def is_tensorboard_available():
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_wandb_available():
if os.getenv("WANDB_DISABLED", "").upper() in {"1", "ON", "YES", "TRUE"}:
return False

return importlib.util.find_spec("wandb") is not None


def is_ray_available():
return importlib.util.find_spec("ray.air") is not None

@@ -37,6 +51,10 @@
integrations = []
if is_visualdl_available():
integrations.append("visualdl")
if is_wandb_available():
integrations.append("wandb")
if is_tensorboard_available():
integrations.append("tensorboard")

return integrations
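
A minimal sketch of how the detection helpers above can be inspected, assuming they are imported from this module (paddlenlp.trainer.integrations):

from paddlenlp.trainer.integrations import get_available_reporting_integrations

# Returns the subset of supported backends importable in the current environment,
# e.g. ["visualdl", "wandb", "tensorboard"] when all three are installed.
print(get_available_reporting_integrations())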

@@ -137,6 +155,246 @@
self.vdl_writer = None


class TensorBoardCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

Args:
tb_writer (`SummaryWriter`, *optional*):
The writer to use. Will instantiate one if not set.
"""

def __init__(self, tb_writer=None):
has_tensorboard = is_tensorboard_available()
if not has_tensorboard:
raise RuntimeError(

"TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
" install tensorboardX."
)
if has_tensorboard:
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401

self._SummaryWriter = SummaryWriter
except ImportError:
try:
from tensorboardX import SummaryWriter

self._SummaryWriter = SummaryWriter
except ImportError:
self._SummaryWriter = None

else:
self._SummaryWriter = None

self.tb_writer = tb_writer

def _init_summary_writer(self, args, log_dir=None):
log_dir = log_dir or args.logging_dir
if self._SummaryWriter is not None:
self.tb_writer = self._SummaryWriter(log_dir=log_dir)

def on_train_begin(self, args, state, control, **kwargs):
if not state.is_world_process_zero:
return

log_dir = None

if self.tb_writer is None:
self._init_summary_writer(args, log_dir)

if self.tb_writer is not None:
self.tb_writer.add_text("args", args.to_json_string())
if "model" in kwargs:
model = kwargs["model"]
if hasattr(model, "config") and model.config is not None:
model_config_json = model.config.to_json_string()
self.tb_writer.add_text("model_config", model_config_json)

def on_log(self, args, state, control, logs=None, **kwargs):
if not state.is_world_process_zero:
return

if self.tb_writer is None:
self._init_summary_writer(args)

if self.tb_writer is not None:
logs = rewrite_logs(logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
"This invocation of Tensorboard's writer.add_scalar() "
"is incorrect so we dropped this attribute."
)
self.tb_writer.flush()

def on_train_end(self, args, state, control, **kwargs):
if self.tb_writer:
self.tb_writer.close()
self.tb_writer = None
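
A usage sketch for the TensorBoardCallback above, assuming it is selected through report_to; the argument names follow the TrainingArguments fields referenced in this diff, and the paths are illustrative:

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    report_to=["tensorboard"],   # resolved to TensorBoardCallback via INTEGRATION_TO_CALLBACK
    logging_dir="./outputs/tb",  # used as the SummaryWriter log_dir
    logging_steps=10,            # scalars are written from on_log() at this interval
)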


class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, and model checkpoints to [Weights & Biases](https://www.wandb.com/).
"""

def __init__(self):
has_wandb = is_wandb_available()
if not has_wandb:
raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")

if has_wandb:
import wandb

self._wandb = wandb
self._initialized = False
# log model
self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()

def setup(self, args, state, model, **kwargs):
"""
Set up the optional Weights & Biases (*wandb*) integration.

One can subclass and override this method to customize the setup if needed.
Environment variables:
- **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
will be uploaded every `args.save_steps`. If set to `"false"`, the model will not be uploaded. Use along
with [`TrainingArguments.load_best_model_at_end`] to upload the best model.
- **WANDB_WATCH** (`str`, *optional*, defaults to `"false"`):
Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
parameters.
- **WANDB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
Set this to a custom string to store results in a different project.
- **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.
"""
if self._wandb is None:
return

# Check if a Weights & Biases (wandb) API key is provided in the training arguments
if args.wandb_api_key:
if self._wandb.api.api_key:
logger.warning(
"A Weights & Biases API key is already configured in the environment. "
"However, the training argument 'wandb_api_key' will take precedence. "
)
self._wandb.login(key=args.wandb_api_key)

self._initialized = True

if state.is_world_process_zero:
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**args.to_dict()}

if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}

trial_name = state.trial_name
init_args = {}
if trial_name is not None:
init_args["name"] = trial_name
init_args["group"] = args.run_name
else:
if not (args.run_name is None or args.run_name == args.output_dir):
init_args["name"] = args.run_name

init_args["dir"] = args.logging_dir
if self._wandb.run is None:
self._wandb.init(
project=os.getenv("WANDB_PROJECT", "PaddleNLP"),
**init_args,
)
# add config parameters (run may have been created manually)
self._wandb.config.update(combined_dict, allow_val_change=True)

# define default x-axis (for latest wandb versions)
if getattr(self._wandb, "define_metric", None):
self._wandb.define_metric("train/global_step")
self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

# keep track of model topology and gradients
_watch_model = os.getenv("WANDB_WATCH", "false")
if _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))

self._wandb.run._label(code="transformers_trainer")
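
A sketch of how the environment variables documented in setup() might be configured before training starts (all values below are examples, not defaults):

import os

os.environ["WANDB_PROJECT"] = "my-paddlenlp-experiments"  # otherwise the run goes to the "PaddleNLP" project
os.environ["WANDB_WATCH"] = "gradients"                   # also track gradients via wandb.watch
os.environ["WANDB_LOG_MODEL"] = "end"                     # upload the final model as an artifact in on_train_end()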

def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
return

if not self._initialized:
self.setup(args, state, model, **kwargs)

def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return

if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
from ..trainer import Trainer

fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
metadata = (
{
k: v
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
if not args.load_best_model_at_end
else {
f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_floss": state.total_flos,
}
)
logger.info("Logging model artifacts. ...")

model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())

self._wandb.run.log_artifact(artifact)

def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if self._wandb is None:
return

if not self._initialized:
self.setup(args, state, model)

if state.is_world_process_zero:
logs = rewrite_logs(logs)
self._wandb.log({**logs, "train/global_step": state.global_step})

def on_save(self, args, state, control, **kwargs):
if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = (
f"checkpoint-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"checkpoint-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path)
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
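
Putting it together, a hypothetical configuration that exercises checkpoint artifact logging from on_save() might look like this; the project and run names are placeholders, MY_WANDB_KEY is an assumed environment variable, and wandb_api_key is the new TrainingArguments field added by this PR:

import os

from paddlenlp.trainer import TrainingArguments

os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # upload an artifact at every save step

args = TrainingArguments(
    output_dir="./outputs",
    report_to=["wandb"],
    run_name="wandb-demo-run",                # used as the wandb run name
    wandb_api_key=os.getenv("MY_WANDB_KEY"),  # optional; omit to rely on an existing `wandb login`
    save_steps=500,
    logging_steps=10,
)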


class AutoNLPCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [`Ray Tune`] for [`AutoNLP`]
@@ -163,6 +421,8 @@
INTEGRATION_TO_CALLBACK = {
"visualdl": VisualDLCallback,
"autonlp": AutoNLPCallback,
"wandb": WandbCallback,
"tensorboard": TensorBoardCallback,
}
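
A small sketch of how this registry could be used to resolve report_to strings into callback classes (this mirrors, but is not necessarily identical to, what the Trainer does internally):

from paddlenlp.trainer.integrations import INTEGRATION_TO_CALLBACK

def callbacks_for(report_to):
    # e.g. callbacks_for(["wandb", "tensorboard"]) -> [WandbCallback, TensorBoardCallback]
    return [INTEGRATION_TO_CALLBACK[name] for name in report_to if name != "none"]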


11 changes: 9 additions & 2 deletions paddlenlp/trainer/training_args.py
@@ -319,8 +319,11 @@ class TrainingArguments:
than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an
instance of `Dataset`.
report_to (`str` or `List[str]`, *optional*, defaults to `"visualdl"`):
The list of integrations to report the results and logs to. Supported platforms is `"visualdl"`.
The list of integrations to report the results and logs to.
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
`"none"` for no integrations.
wandb_api_key (`str`, *optional*):
Weights & Biases (WandB) API key(s) for authentication with the WandB service.
resume_from_checkpoint (`str`, *optional*):
The path to a folder with a valid checkpoint for your model. This argument is not directly used by
[`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
@@ -695,6 +698,10 @@ class TrainingArguments:
report_to: Optional[List[str]] = field(
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
)
wandb_api_key: Optional[str] = field(
default=None,
metadata={"help": "Weights & Biases (WandB) API key(s) for authentication with the WandB service."},
)
resume_from_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
@@ -1318,7 +1325,7 @@ def is_segment_parallel_supported():
"integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as "
"now. You should start updating your code and make this info disappear :-)."
)
self.report_to = "all"
self.report_to = "visualdl"
if self.report_to == "all" or self.report_to == ["all"]:
# Import at runtime to avoid a circular import.
from .integrations import get_available_reporting_integrations
5 changes: 4 additions & 1 deletion requirements-dev.txt
@@ -16,4 +16,7 @@ soundfile
librosa
numpy==1.23.5
rouge
tiktoken
tiktoken
visualdl
wandb
tensorboard