Skip to content

Commit

Permalink
Fix wav2vec2 reported-loss bug (running-average loss accumulated incorrectly)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zth9730 committed Oct 16, 2022
1 parent 49c0cf9 commit 86f65f0
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions paddlespeech/s2t/exps/wav2vec2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Contains wav2vec2 model."""
import json
import math
import os
import time
from collections import defaultdict
Expand Down Expand Up @@ -46,25 +47,20 @@
class Wav2Vec2ASRTrainer(Trainer):
def __init__(self, config, args):
    """Initialize the wav2vec2 ASR trainer.

    Arguments
    ---------
    config : object
        training configuration (presumably a yacs/CfgNode — TODO confirm).
    args : object
        command-line arguments forwarded to the base Trainer.
    """
    super().__init__(config, args)
    # Running average of the training loss, maintained in-place by
    # update_average(). Initialized as a float (0.0, not 0) so the
    # average stays a plain Python float throughout training.
    self.avg_train_loss = 0.0

def update_average(self, batch_index, loss, avg_loss):
def update_average(self, batch_index, loss):
"""Update running average of the loss.
Arguments
---------
batch_index : int
current batch index
loss : paddle.tensor
detached loss, a single float value.
avg_loss : float
current running average.
Returns
-------
avg_loss : float
The average loss.
"""
if paddle.isfinite(loss):
avg_loss -= avg_loss / (batch_index + 1)
avg_loss += float(loss) / (batch_index + 1)
return avg_loss
if math.isfinite(loss):
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
self.avg_train_loss += loss / (batch_index + 1)

def train_batch(self, batch_index, batch, msg):
train_conf = self.config
Expand All @@ -80,8 +76,8 @@ def train_batch(self, batch_index, batch, msg):
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad

self.avg_train_loss = self.update_average(batch_index, loss,
self.avg_train_loss)
# update self.avg_train_loss
self.update_average(batch_index, float(loss))

# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
Expand All @@ -106,7 +102,7 @@ def train_batch(self, batch_index, batch, msg):
self.lr_scheduler.step()
self.iteration += 1

losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
iteration_time = time.time() - start
for k, v in losses_np.items():
report(k, v)
Expand Down

0 comments on commit 86f65f0

Please sign in to comment.