Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed May 3, 2024
1 parent 1f4d67b commit 04dcb8a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion dnn/torch/osce/engine/bwe_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def evaluate(model, criterion, dataloader, device, preemph_gamma=0, log_interval
batch[key] = batch[key].to(device)

target = batch['x_48']
x_up = model.upsampler(batch['x_16'].unsqueeze(1))

# calculate model output
output = model(batch['x_16'].unsqueeze(1), batch['features'])
Expand All @@ -95,7 +96,7 @@ def evaluate(model, criterion, dataloader, device, preemph_gamma=0, log_interval
output = preemph(output, preemph_gamma)

# calculate loss
loss = criterion(target, output.squeeze(1), model.upsampler(batch['x_16'].unsqueeze(1)))
loss = criterion(target, output.squeeze(1), x_up)

# update running loss
running_loss += float(loss.cpu())
Expand Down
2 changes: 1 addition & 1 deletion dnn/torch/osce/train_bwe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def criterion(x, y, x_up):

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"{model.flop_count(48000) / 1e6:5.3f} MFLOPS")


best_loss = 1e9
Expand Down

0 comments on commit 04dcb8a

Please sign in to comment.