diff --git a/dnn/torch/osce/engine/bwe_engine.py b/dnn/torch/osce/engine/bwe_engine.py index 838cf368c..9ea4c0b73 100644 --- a/dnn/torch/osce/engine/bwe_engine.py +++ b/dnn/torch/osce/engine/bwe_engine.py @@ -85,6 +85,7 @@ def evaluate(model, criterion, dataloader, device, preemph_gamma=0, log_interval batch[key] = batch[key].to(device) target = batch['x_48'] + x_up = model.upsampler(batch['x_16'].unsqueeze(1)) # calculate model output output = model(batch['x_16'].unsqueeze(1), batch['features']) @@ -95,7 +96,7 @@ def evaluate(model, criterion, dataloader, device, preemph_gamma=0, log_interval output = preemph(output, preemph_gamma) # calculate loss - loss = criterion(target, output.squeeze(1), model.upsampler(batch['x_16'].unsqueeze(1))) + loss = criterion(target, output.squeeze(1), x_up) # update running loss running_loss += float(loss.cpu()) diff --git a/dnn/torch/osce/train_bwe_model.py b/dnn/torch/osce/train_bwe_model.py index 1ec4a2c1e..b2c49ba4f 100644 --- a/dnn/torch/osce/train_bwe_model.py +++ b/dnn/torch/osce/train_bwe_model.py @@ -263,7 +263,7 @@ def criterion(x, y, x_up): print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters") if hasattr(model, 'flop_count'): - print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS") + print(f"{model.flop_count(48000) / 1e6:5.3f} MFLOPS") best_loss = 1e9