preprocessing.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import torch.nn as nn
from apex import amp
from helpers import Optimization
from parts.features import FeatureFactory


class SpecCutoutRegions(nn.Module):
    """Cutout augmentation. Refer to https://arxiv.org/pdf/1708.04552.pdf
    """

    def __init__(self, cfg):
        super(SpecCutoutRegions, self).__init__()

        self.cutout_rect_regions = cfg.get('cutout_rect_regions', 0)
        self.cutout_rect_time = cfg.get('cutout_rect_time', 5)
        self.cutout_rect_freq = cfg.get('cutout_rect_freq', 20)

    @torch.no_grad()
    def forward(self, x):
        # x: spectrogram batch of shape (batch, freq, time).
        sh = x.shape
        mask = torch.zeros(x.shape).bool()
        for idx in range(sh[0]):
            for _ in range(self.cutout_rect_regions):
                # Pick a random top-left corner, then mark a
                # cutout_rect_freq x cutout_rect_time rectangle for masking.
                cutout_rect_x = int(random.uniform(
                    0, sh[1] - self.cutout_rect_freq))
                cutout_rect_y = int(random.uniform(
                    0, sh[2] - self.cutout_rect_time))
                mask[idx,
                     cutout_rect_x:cutout_rect_x + self.cutout_rect_freq,
                     cutout_rect_y:cutout_rect_y + self.cutout_rect_time] = 1
        # Zero out every marked position in one call.
        x = x.masked_fill(mask.to(device=x.device), 0)
        return x
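
# Usage sketch (illustrative; not part of the original file). The config keys
# are the ones SpecCutoutRegions.__init__ reads above, and the tensor shape
# follows how forward() indexes its input:
#
#   cutout = SpecCutoutRegions({'cutout_rect_regions': 2,
#                               'cutout_rect_freq': 15,
#                               'cutout_rect_time': 25})
#   spec = torch.randn(8, 64, 300)   # (batch, freq_bins, time_frames)
#   spec = cutout(spec)              # zeros two 15x25 rectangles per sample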


class SpecAugment(nn.Module):
    """SpecAugment. Refer to https://arxiv.org/abs/1904.08779
    """

    def __init__(self, cfg):
        super(SpecAugment, self).__init__()

        self.cutout_x_regions = cfg.get('cutout_x_regions', 0)
        self.cutout_y_regions = cfg.get('cutout_y_regions', 0)
        self.cutout_x_width = cfg.get('cutout_x_width', 10)
        self.cutout_y_width = cfg.get('cutout_y_width', 10)

    @torch.no_grad()
    def forward(self, x):
        # x: spectrogram batch of shape (batch, freq, time).
        sh = x.shape
        mask = torch.zeros(x.shape).bool()
        for idx in range(sh[0]):
            # Frequency masking: mark full-width horizontal bands.
            for _ in range(self.cutout_x_regions):
                cutout_x_left = int(random.uniform(0, sh[1] - self.cutout_x_width))
                mask[idx, cutout_x_left:cutout_x_left + self.cutout_x_width, :] = 1
            # Time masking: mark full-height vertical bands.
            for _ in range(self.cutout_y_regions):
                cutout_y_left = int(random.uniform(0, sh[2] - self.cutout_y_width))
                mask[idx, :, cutout_y_left:cutout_y_left + self.cutout_y_width] = 1
        # Zero out every marked position in one call.
        x = x.masked_fill(mask.to(device=x.device), 0)
        return x
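
# Usage sketch (illustrative; not part of the original file). Unlike the
# rectangles cut by SpecCutoutRegions, SpecAugment masks full-width frequency
# bands and full-height time bands, as in the SpecAugment paper:
#
#   augment = SpecAugment({'cutout_x_regions': 2, 'cutout_x_width': 6,
#                          'cutout_y_regions': 2, 'cutout_y_width': 6})
#   spec = torch.randn(8, 64, 300)   # (batch, freq_bins, time_frames)
#   spec = augment(spec)             # masks 2 freq bands and 2 time bands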


class SpectrogramAugmentation(nn.Module):
    """Spectrogram augmentation: Cutout rectangles followed by SpecAugment
    band masking.
    """

    def __init__(self, **kwargs):
        nn.Module.__init__(self)
        self.spec_cutout_regions = SpecCutoutRegions(kwargs)
        self.spec_augment = SpecAugment(kwargs)

    @torch.no_grad()
    def forward(self, input_spec):
        augmented_spec = self.spec_cutout_regions(input_spec)
        augmented_spec = self.spec_augment(augmented_spec)
        return augmented_spec


class AudioPreprocessing(nn.Module):
    """GPU accelerated audio preprocessing
    """

    def __init__(self, **kwargs):
        nn.Module.__init__(self)  # For PyTorch API
        self.optim_level = kwargs.get('optimization_level', Optimization.nothing)
        self.featurizer = FeatureFactory.from_config(kwargs)

    def forward(self, x):
        input_signal, length = x
        # Sequence lengths are bookkeeping only; never differentiate through them.
        length.requires_grad_(False)
        if self.optim_level not in [Optimization.nothing, Optimization.mxprO0,
                                    Optimization.mxprO3]:
            # Under mixed precision (O1/O2), run feature extraction with
            # autocasting disabled so it stays in full fp32.
            with amp.disable_casts():
                processed_signal = self.featurizer(x)
                processed_length = self.featurizer.get_seq_len(length)
        else:
            processed_signal = self.featurizer(x)
            processed_length = self.featurizer.get_seq_len(length)
        return processed_signal, processed_length
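

if __name__ == '__main__':
    # Smoke-test sketch (not part of the original file): exercise the
    # augmentation pipeline on random data. The config keys are those read by
    # SpecCutoutRegions and SpecAugment above; the (batch, freq, time) shape
    # is assumed from how forward() indexes its input. AudioPreprocessing is
    # not exercised here because it needs a real FeatureFactory config from
    # parts.features.
    aug = SpectrogramAugmentation(
        cutout_rect_regions=2, cutout_rect_freq=15, cutout_rect_time=25,
        cutout_x_regions=2, cutout_x_width=6,
        cutout_y_regions=2, cutout_y_width=6)
    spec = torch.randn(4, 64, 300)
    out = aug(spec)
    print('output shape:', tuple(out.shape),
          '| fraction zeroed: %.3f' % (out == 0).float().mean().item())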