Skip to content

Commit

Permalink
Reproduce #1781. Add Weights and Biases support
Browse files Browse the repository at this point in the history
Summary:

Fixes facebookresearch/fairseq#1790.

Reviewed By: alexeib

Differential Revision: D24579153

fbshipit-source-id: 74a30effa164db9d6376554376e36b1f47618899

Co-authored-by: Nikolay Korolev <korolevns98@gmail.com>
Co-authored-by: Vlad Lyalin <Guitaricet@gmail.com>
  • Loading branch information
3 people authored and facebook-github-bot committed Nov 4, 2020
1 parent dd52ed0 commit 1a709b2
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,6 @@ data-bin/

# Experimental Folder
experimental/*

# Weights and Biases logs
wandb/
1 change: 1 addition & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ common:
log_interval: 100
log_format: null
tensorboard_logdir: null
wandb_project: null
seed: 1
cpu: false
tpu: false
Expand Down
6 changes: 6 additions & 0 deletions fairseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ class CommonConfig(FairseqDataclass):
"of running tensorboard (default: no tensorboard logging)"
},
)
# Weights & Biases project name; W&B logging is opt-in and only enabled
# when this is set to a non-None value (see progress_bar in logging/progress_bar.py).
wandb_project: Optional[str] = field(
default=None,
metadata={
"help": "Weights and Biases project name to use for logging"
},
)
seed: int = field(
default=1, metadata={"help": "pseudo random number generator seed"}
)
Expand Down
51 changes: 51 additions & 0 deletions fairseq/logging/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def progress_bar(
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
default_log_format: str = "tqdm",
wandb_project: Optional[str] = None,
):
if log_format is None:
log_format = default_log_format
Expand Down Expand Up @@ -60,6 +61,9 @@ def progress_bar(
except ImportError:
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

if wandb_project:
bar = WandBProgressBarWrapper(bar, wandb_project)

return bar


Expand Down Expand Up @@ -353,3 +357,50 @@ def _log_to_tensorboard(self, stats, tag=None, step=None):
elif isinstance(stats[key], Number):
writer.add_scalar(key, stats[key], step)
writer.flush()


try:
import wandb
except ImportError:
wandb = None


class WandBProgressBarWrapper(BaseProgressBar):
"""Log to Weights & Biases.

Wraps another progress bar and mirrors its stats to a W&B run,
delegating all display behavior to the wrapped bar. If the ``wandb``
package is not installed, the wrapper degrades to a pass-through.
"""

def __init__(self, wrapped_bar, wandb_project):
# wrapped_bar: the underlying progress bar all calls are delegated to.
# wandb_project: project name passed to wandb.init().
self.wrapped_bar = wrapped_bar
if wandb is None:
# wandb is an optional dependency; warn once and act as a no-op logger.
logger.warning('wandb not found, pip install wandb')
return

# reinit=False to ensure if wandb.init() is called multiple times
# within one process it still references the same run
wandb.init(project=wandb_project, reinit=False)

def __iter__(self):
# Iteration (and thus the visual progress display) is fully delegated.
return iter(self.wrapped_bar)

def log(self, stats, tag=None, step=None):
"""Log intermediate stats to Weights & Biases, then delegate to the wrapped bar."""
self._log_to_wandb(stats, tag, step)
self.wrapped_bar.log(stats, tag=tag, step=step)

def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats, mirroring them to Weights & Biases first."""
self._log_to_wandb(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)

def _log_to_wandb(self, stats, tag=None, step=None):
# No-op when the wandb package is unavailable (see __init__).
if wandb is None:
return
if step is None:
# NOTE(review): assumes 'num_updates' is always present in stats when
# no explicit step is passed — this raises KeyError otherwise; confirm
# all call sites provide one or the other.
step = stats['num_updates']

# Namespace metric keys by tag (e.g. 'train/loss') so charts group per split.
prefix = '' if tag is None else tag + '/'

# 'num_updates' serves as the step axis itself, so it is excluded
# from the logged values. AverageMeter stats log their current value;
# plain numbers are logged as-is; everything else is silently skipped.
for key in stats.keys() - {'num_updates'}:
if isinstance(stats[key], AverageMeter):
wandb.log({prefix + key: stats[key].val}, step=step)
elif isinstance(stats[key], Number):
wandb.log({prefix + key: stats[key]}, step=step)
10 changes: 8 additions & 2 deletions fairseq_cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr)
tensorboard_logdir=(
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
),
default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
wandb_project=(
cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
),
)

trainer.begin_epoch(epoch_itr.epoch)
Expand Down Expand Up @@ -307,7 +310,10 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
tensorboard_logdir=(
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
),
default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
wandb_project=(
cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
),
)

# create a new root metrics aggregator so validation metrics
Expand Down

0 comments on commit 1a709b2

Please sign in to comment.