Fix loss reporting (facebookresearch#453)
Summary:
Pull Request resolved: facebookresearch#453

When computing average loss for train/eval/test, loss across different batches should be weighted by batch size.

This doesn't matter for standard training, where every batch except possibly the last is the full batch size. However, when training in a massively distributed fashion (e.g., for Federated Learning), data is split into many small partitions, so many batches are not 'complete' batches and weighting by batch size becomes important.
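
To illustrate (a hypothetical sketch, not part of this commit): with uneven batch sizes, a plain mean over per-batch losses can differ noticeably from the true per-example average.

    import numpy as np

    # Hypothetical per-batch mean losses; the last batch is much smaller.
    losses = [0.5, 0.5, 2.0]
    sizes = [64, 64, 8]

    unweighted = sum(losses) / len(losses)        # 1.0   -- the small batch dominates
    weighted = np.average(losses, weights=sizes)  # ~0.59 -- true per-example average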

Reviewed By: gardenia22

Differential Revision: D14791247

fbshipit-source-id: 57b1142638755868a208dad233070a78537c7f3e
Kshitiz Malik authored and facebook-github-bot committed Apr 8, 2019
1 parent 3c7ecd4 commit fb58d20
Showing 1 changed file with 5 additions and 2 deletions.
pytext/metric_reporters/metric_reporter.py
@@ -2,6 +2,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 from typing import Dict, List

+import numpy as np
 import torch
 from pytext.config.component import Component, ComponentType
 from pytext.config.pytext_config import ConfigBase
@@ -74,7 +75,9 @@ def add_batch_stats(
         if key not in self.all_context:
             self.all_context[key] = []
         self.aggregate_data(self.all_context[key], val)
-        self.all_loss.append(loss)
+        # some loss functions (eg: in NewBertRegressionTask) return a tensor
+        # convert tensor to float
+        self.all_loss.append(float(loss))
         self.batch_size.append(len(m_input[0]))

     def aggregate_preds(self, new_batch):
@@ -122,7 +125,7 @@ def calculate_loss(self):
"""
Calculate the average loss for all aggregated batch
"""
return sum(self.all_loss) / float(len(self.all_loss))
return np.average(self.all_loss, weights=self.batch_size)

def calculate_metric(self):
"""
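
For reference (hypothetical values, not from the commit), np.average(x, weights=w) computes sum(x_i * w_i) / sum(w_i), i.e. the average loss per example rather than per batch; the added float() call also handles loss functions that return a 0-dim tensor:

    import numpy as np
    import torch

    all_loss = [float(torch.tensor(0.9)), 1.1, 0.4]  # some losses arrive as tensors
    batch_size = [32, 32, 4]                         # examples contributed per batch

    avg = np.average(all_loss, weights=batch_size)
    manual = sum(l * s for l, s in zip(all_loss, batch_size)) / sum(batch_size)
    assert abs(avg - manual) < 1e-12  # same per-example average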
