Skip to content

Commit 35c7f31

Browse files
committed
[Fix] loudness: prevent NaN when all blocks are below absolute threshold
1 parent 87ff22e commit 35c7f31

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

src/torchaudio/functional/functional.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,15 +1578,35 @@ def loudness(waveform: Tensor, sample_rate: int):
15781578
gated_blocks = loudness > gamma_abs
15791579
gated_blocks = gated_blocks.unsqueeze(-2)
15801580

1581-
energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
1581+
# Compute numerator and denominator
1582+
sum_gated_energy = torch.sum(gated_blocks * energy, dim=-1)
1583+
count_gated_blocks = torch.count_nonzero(gated_blocks, dim=-1)
1584+
1585+
# Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0
1586+
energy_filtered = torch.where(
1587+
count_gated_blocks > 0,
1588+
sum_gated_energy / count_gated_blocks,
1589+
torch.tensor(0.0, dtype=sum_gated_energy.dtype, device=sum_gated_energy.device)
1590+
)
1591+
15821592
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
15831593
gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10
15841594

15851595
# Apply relative gating of the blocks
15861596
gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1))
15871597
gated_blocks = gated_blocks.unsqueeze(-2)
15881598

1589-
energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
1599+
# Compute numerator and denominator
1600+
sum_gated_energy = torch.sum(gated_blocks * energy, dim=-1)
1601+
count_gated_blocks = torch.count_nonzero(gated_blocks, dim=-1)
1602+
1603+
# Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0
1604+
energy_filtered = torch.where(
1605+
count_gated_blocks > 0,
1606+
sum_gated_energy / count_gated_blocks,
1607+
torch.tensor(0.0, dtype=sum_gated_energy.dtype, device=sum_gated_energy.device)
1608+
)
1609+
15901610
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
15911611
LKFS = kweight_bias + 10 * torch.log10(energy_weighted)
15921612
return LKFS

0 commit comments

Comments
 (0)