@@ -1578,15 +1578,35 @@ def loudness(waveform: Tensor, sample_rate: int):
15781578    gated_blocks  =  loudness  >  gamma_abs 
15791579    gated_blocks  =  gated_blocks .unsqueeze (- 2 )
15801580
1581-     energy_filtered  =  torch .sum (gated_blocks  *  energy , dim = - 1 ) /  torch .count_nonzero (gated_blocks , dim = - 1 )
1581+     # Compute numerator and denominator 
1582+     sum_gated_energy  =  torch .sum (gated_blocks  *  energy , dim = - 1 )
1583+     count_gated_blocks  =  torch .count_nonzero (gated_blocks , dim = - 1 )
1584+ 
1585+     # Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0 
1586+     energy_filtered  =  torch .where (
1587+         count_gated_blocks  >  0 ,
1588+         sum_gated_energy  /  count_gated_blocks ,
1589+         torch .tensor (0.0 , dtype = sum_gated_energy .dtype , device = sum_gated_energy .device )
1590+     )
1591+ 
15821592    energy_weighted  =  torch .sum (g  *  energy_filtered , dim = - 1 )
15831593    gamma_rel  =  kweight_bias  +  10  *  torch .log10 (energy_weighted ) -  10 
15841594
15851595    # Apply relative gating of the blocks 
15861596    gated_blocks  =  torch .logical_and (gated_blocks .squeeze (- 2 ), loudness  >  gamma_rel .unsqueeze (- 1 ))
15871597    gated_blocks  =  gated_blocks .unsqueeze (- 2 )
15881598
1589-     energy_filtered  =  torch .sum (gated_blocks  *  energy , dim = - 1 ) /  torch .count_nonzero (gated_blocks , dim = - 1 )
1599+     # Compute numerator and denominator 
1600+     sum_gated_energy  =  torch .sum (gated_blocks  *  energy , dim = - 1 )
1601+     count_gated_blocks  =  torch .count_nonzero (gated_blocks , dim = - 1 )
1602+ 
1603+     # Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0 
1604+     energy_filtered  =  torch .where (
1605+         count_gated_blocks  >  0 ,
1606+         sum_gated_energy  /  count_gated_blocks ,
1607+         torch .tensor (0.0 , dtype = sum_gated_energy .dtype , device = sum_gated_energy .device )
1608+     )
1609+ 
15901610    energy_weighted  =  torch .sum (g  *  energy_filtered , dim = - 1 )
15911611    LKFS  =  kweight_bias  +  10  *  torch .log10 (energy_weighted )
15921612    return  LKFS 
0 commit comments