[wip/s2s/pl] attempt to sync metrics in DDP #8269

Closed · wants to merge 6 commits
19 changes: 15 additions & 4 deletions examples/seq2seq/finetune.py
@@ -13,13 +13,15 @@
import numpy as np
import pytorch_lightning as pl
import torch
from datasets import load_metric
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
ROUGE_KEYS,
AverageMetric,
LegacySeq2SeqDataset,
Seq2SeqDataset,
assert_all_frozen,
@@ -104,16 +106,19 @@ def __init__(self, hparams, **kwargs):
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        self.dataset_class = (
-            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
-        )
+        # if prepare_seq2seq_batch raises NotImplementedError, replace this with LegacyDataset
+        self.dataset_class = Seq2SeqDataset
self.already_saved_batch = False
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
if self.hparams.eval_max_gen_length is not None:
self.eval_max_length = self.hparams.eval_max_gen_length
else:
self.eval_max_length = self.model.config.max_length

self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
self.metric_stores = {
k: AverageMetric() for k in self.metric_names + ["gen_time", "gen_len"] + self.loss_names
}

def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
"""A debugging utility"""
@@ -188,6 +193,7 @@ def validation_epoch_end(self, outputs, prefix="val") -> Dict:
generative_metrics = {
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
}

metric_val = (
generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
)
@@ -196,7 +202,9 @@ def validation_epoch_end(self, outputs, prefix="val") -> Dict:
losses.update(generative_metrics)
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
all_metrics["step_count"] = self.step_count
-        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
+        self.metrics[prefix].append(all_metrics)  # written to self.metrics_save_path
pl_metrics = {f"pl_{prefix}_avg_{k}": v.compute().item() for k, v in self.metric_stores.items()}
all_metrics.update(pl_metrics)
preds = flatten_list([x["preds"] for x in outputs])
return {
"log": all_metrics,
@@ -228,6 +236,9 @@ def _generative_step(self, batch: dict) -> dict:
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
for k, v in base_metrics.items():
if k in self.metric_stores:
self.metric_stores[k].update(base_metrics[k])
return base_metrics

def test_step(self, batch, batch_idx):
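Taken together, the finetune.py changes touch three points: the stores are created once in `__init__`, fed one scalar per batch in `_generative_step`, and reduced once per epoch in `validation_epoch_end`. Below is a minimal, Trainer-free sketch of that cycle; it assumes the `AverageMetric` this PR adds to `examples/seq2seq/utils.py` is importable, and the metric names and values are made up for illustration.

```python
from utils import AverageMetric  # added to examples/seq2seq/utils.py by this PR

# __init__: one store per scalar to be averaged (and DDP-synced) over the epoch
metric_stores = {name: AverageMetric() for name in ["rouge2", "gen_time", "loss"]}

# _generative_step: feed each store the scalar computed for the current batch
fake_step_metrics = [
    {"rouge2": 21.3, "gen_time": 0.41, "loss": 2.10},
    {"rouge2": 19.8, "gen_time": 0.38, "loss": 2.04},
]
for step in fake_step_metrics:
    for name, value in step.items():
        metric_stores[name].update(value)

# validation_epoch_end: compute() averages the accumulated values; under DDP the
# summed states are all-reduced across processes first, so every rank ends up
# with the same number (logged under the pl_ prefix in the diff above)
pl_metrics = {f"pl_val_avg_{k}": v.compute().item() for k, v in metric_stores.items()}
print(pl_metrics)

# reset() clears the accumulated state before the next epoch (part of the
# Lightning Metric API; the diff does not show this step)
for store in metric_stores.values():
    store.reset()
```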
22 changes: 22 additions & 0 deletions examples/seq2seq/utils.py
@@ -53,6 +53,28 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
return loss, nll_loss


from pytorch_lightning.metrics import Metric


class AverageMetric(Metric):
def __init__(self, dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.add_state("loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, value):
        # accumulate one scalar observation; both states are registered with
        # dist_reduce_fx="sum", so they can be summed across DDP processes later
        self.loss += torch.tensor(value, dtype=self.loss.dtype, device=self.loss.device)
        self.total += 1

    def compute(self):
        # mean over all updates seen (across all processes once the states are synced)
        return self.loss.float() / self.total


def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
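A note on why both states use dist_reduce_fx="sum": when compute() runs under DDP, Lightning reduces each registered state across processes before the division, so the result is the mean over every update from every rank, weighted by how many batches each rank actually saw, rather than a mean of per-rank means. The single-process sketch below just walks through that arithmetic with made-up per-rank values; it does not launch distributed training.

```python
# Simulate what the "sum" reduction does for two ranks, without launching DDP.
# Rank 0 saw 3 batches, rank 1 saw 2; the values are made up.
rank0_values = [2.0, 4.0, 6.0]
rank1_values = [10.0, 20.0]

# per-rank states after update(): (sum of values, number of updates)
rank0_state = (sum(rank0_values), len(rank0_values))  # (12.0, 3)
rank1_state = (sum(rank1_values), len(rank1_values))  # (30.0, 2)

# dist_reduce_fx="sum" all-reduces each state across ranks before compute()
global_loss = rank0_state[0] + rank1_state[0]   # 42.0
global_total = rank0_state[1] + rank1_state[1]  # 5

print(global_loss / global_total)  # 8.4, the mean over all 5 updates
# Averaging the per-rank means instead would give (4.0 + 15.0) / 2 = 9.5, which
# is biased whenever ranks process different numbers of batches.
```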