Skip to content

Commit 8701581

Browse files
feat: add loudness transform, alltransform, refactor into folders
1 parent 117d96f commit 8701581

File tree

15 files changed

+178
-90
lines changed

15 files changed

+178
-90
lines changed

audio_data_pytorch/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
from .ljspeech_dataset import LJSpeechDataset
2-
from .transforms import Crop, OverlapChannels, RandomCrop, Resample, Scale
3-
from .wav_dataset import WAVDataset
4-
from .youtube_dataset import YoutubeDataset
1+
from .datasets import * # noqa
2+
from .transforms import * # noqa
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .ljspeech_dataset import LJSpeechDataset
2+
from .wav_dataset import WAVDataset
3+
from .youtube_dataset import YoutubeDataset

audio_data_pytorch/ljspeech_dataset.py renamed to audio_data_pytorch/datasets/ljspeech_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import requests # type: ignore
55
from tqdm import tqdm
66

7-
from .utils import camel_to_snake
7+
from ..utils import camel_to_snake
88
from .wav_dataset import WAVDataset
99

1010

audio_data_pytorch/youtube_dataset.py renamed to audio_data_pytorch/datasets/youtube_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch.nn import functional as F
99
from tqdm import tqdm
1010

11-
from .utils import camel_to_snake, exists
11+
from ..utils import camel_to_snake, exists
1212
from .wav_dataset import WAVDataset
1313

1414

audio_data_pytorch/transforms.py

Lines changed: 0 additions & 82 deletions
This file was deleted.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .all import AllTransform
2+
from .crop import Crop
3+
from .loudness import Loudness
4+
from .overlap_channels import OverlapChannels
5+
from .randomcrop import RandomCrop
6+
from .resample import Resample
7+
from .scale import Scale

audio_data_pytorch/transforms/all.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Optional
2+
3+
from torch import Tensor, nn
4+
5+
from ..utils import exists
6+
from .crop import Crop
7+
from .loudness import Loudness
8+
from .overlap_channels import OverlapChannels
9+
from .randomcrop import RandomCrop
10+
from .resample import Resample
11+
from .scale import Scale
12+
13+
14+
class AllTransform(nn.Module):
    """Bundles the optional waveform transforms of this package into one module.

    The enabled steps always run in this order: resample, random crop, crop,
    overlap channels, loudness normalization, scale. A step that is not
    enabled is replaced by ``nn.Identity`` so that the positions inside
    ``self.transform`` do not depend on the configuration.

    Args:
        source_rate: Passed to ``Resample`` as ``source``. Must be provided
            together with ``target_rate``.
        target_rate: Passed to ``Resample`` as ``target`` and, when
            ``loudness`` is set, to ``Loudness`` as ``sampling_rate``.
            Resampling is skipped when it equals ``source_rate``.
        crop_size: When set, ``Crop(crop_size)`` is applied.
        random_crop_size: When set, ``RandomCrop(random_crop_size)`` is
            applied (before ``Crop``).
        loudness: When set, passed to ``Loudness`` as ``target``; requires
            ``target_rate``. Fractional targets are accepted.
        scale: When set, ``Scale(scale)`` is applied.
        overlap_channels: When true, ``OverlapChannels()`` is applied.

    Raises:
        AssertionError: If exactly one of ``source_rate`` / ``target_rate``
            is provided, or if ``loudness`` is provided without
            ``target_rate``.
    """

    def __init__(
        self,
        source_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        crop_size: Optional[int] = None,
        random_crop_size: Optional[int] = None,
        loudness: Optional[float] = None,
        scale: Optional[float] = None,
        overlap_channels: bool = False,
    ):
        super().__init__()

        # The two rates are only meaningful as a pair: both given or both omitted.
        message = "Both source_rate and target_rate must be provided"
        assert not (exists(source_rate) ^ exists(target_rate)), message

        # Loudness runs after resampling, so its meter is built for target_rate.
        message = "Loudness requires target_rate"
        assert not exists(loudness) or exists(target_rate), message

        self.transform = nn.Sequential(
            Resample(source=source_rate, target=target_rate)  # type: ignore
            if exists(source_rate) and source_rate != target_rate
            else nn.Identity(),
            RandomCrop(random_crop_size) if exists(random_crop_size) else nn.Identity(),
            Crop(crop_size) if exists(crop_size) else nn.Identity(),
            OverlapChannels() if overlap_channels else nn.Identity(),
            Loudness(sampling_rate=target_rate, target=loudness)  # type: ignore
            if exists(loudness)
            else nn.Identity(),
            Scale(scale) if exists(scale) else nn.Identity(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Applies the configured chain of transforms to ``x``."""
        return self.transform(x)

audio_data_pytorch/transforms/crop.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
from torch import Tensor, nn
3+
4+
5+
class Crop(nn.Module):
    """Cuts a waveform to a fixed number of samples, zero-padding if too short.

    Args:
        size: Number of samples in the returned waveform.
        start: Index of the first sample to keep (default 0).
    """

    def __init__(self, size: int, start: int = 0) -> None:
        super().__init__()
        self.start = start
        self.size = size

    def forward(self, x: Tensor) -> Tensor:
        """Returns ``x`` (channels, length) cropped or right-padded to ``size``."""
        # Discard everything before the start offset.
        tail = x[:, self.start :]
        num_channels, num_samples = tail.shape
        missing = self.size - num_samples

        # Enough samples available: a plain slice is all that is needed.
        if missing <= 0:
            return tail[:, : self.size]

        # Too short: append zeros matching the dtype and device of the input.
        filler = tail.new_zeros(num_channels, missing)
        return torch.cat((tail, filler), dim=1)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pyloudnorm as pyln
2+
import torch
3+
from torch import Tensor, nn
4+
5+
6+
class Loudness(nn.Module):
    """Normalizes to target loudness using BS.1770-4, requires pyloudnorm.

    Args:
        sampling_rate: Sampling rate used to build the pyloudnorm meter.
        target: Target loudness handed to ``pyln.normalize.loudness``.
    """

    def __init__(self, sampling_rate: int, target: float):
        super().__init__()
        self.sampling_rate = sampling_rate
        self.target = target
        self.meter = pyln.Meter(sampling_rate)

    def forward(self, x: Tensor) -> Tensor:
        """Returns ``x`` (channels, length) normalized to the target loudness.

        The measurement runs in NumPy on the CPU, so the input is detached
        from autograd and copied to host memory first; the result is moved
        back to the dtype and device of ``x``. A silent input (measured
        loudness of ``-inf``) is returned unchanged.
        """
        # Unpacking doubles as a check that the input is 2-D.
        _channels, _length = x.shape
        # Measure sample loudness; the array is transposed to (length, channels).
        # detach()/cpu() make the conversion work for CUDA or grad-tracking input.
        x_numpy = x.detach().cpu().numpy().T
        loudness = self.meter.integrated_loudness(data=x_numpy)
        # Don't normalize zeros sample (i.e. silence)
        if loudness == -float("inf"):
            return x
        # Normalize sample loudness
        x_normalized = pyln.normalize.loudness(
            data=x_numpy, input_loudness=loudness, target_loudness=self.target
        )
        # Return normalized as torch Tensor with the dtype and device of the
        # input (NumPy may have promoted the dtype during the gain multiply).
        return torch.from_numpy(x_normalized.T).to(x)

0 commit comments

Comments
 (0)