-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathmodels.py
293 lines (235 loc) · 13.5 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys
class SingleRNN(nn.Module):
    """
    A single recurrent layer followed by a linear projection that maps the
    recurrent output back to the input feature width.

    args:
        rnn_type: string, name of the torch.nn recurrent class to instantiate
                  ('RNN', 'LSTM' or 'GRU').
        input_size: int, feature dimension of the input. The input is laid out as
                    (batch, seq_len, input_size).
        hidden_size: int, dimension of the recurrent hidden state.
        dropout: float, dropout value forwarded to the recurrent layer. Default is 0.
        bidirectional: bool, whether the recurrent layer runs in both directions.
                       Default is False.
    """

    def __init__(self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False):
        super().__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        # 1 for a unidirectional layer, 2 for a bidirectional one
        self.num_direction = 1 + int(bidirectional)

        # look up the recurrent class by name and build one batch-first layer
        rnn_cls = getattr(nn, rnn_type)
        self.rnn = rnn_cls(
            input_size,
            hidden_size,
            1,
            dropout=dropout,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # linear projection from (hidden_size * directions) back to input_size
        self.proj = nn.Linear(hidden_size * self.num_direction, input_size)

    def forward(self, input):
        """
        Run the recurrent layer and project every time step back to input_size.

        input: tensor of shape (batch, seq, dim).
        returns: tensor with the same shape as `input`.
        """
        hidden_seq, _ = self.rnn(input)
        # flatten (batch, seq) so the projection sees one row per time step
        flat = hidden_seq.reshape(-1, hidden_seq.size(2))
        projected = self.proj(flat)
        return projected.view(input.shape)
# dual-path RNN
class DPRNN(nn.Module):
    """
    Deep dual-path RNN.

    Every layer applies an RNN along dim1 (the "row" / intra-segment path) and
    then an RNN along dim2 (the "column" / inter-segment path). Each path is
    followed by a normalization and a residual connection.

    args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature (the N axis of the
                    (batch, N, dim1, dim2) input).
        hidden_size: int, dimension of the hidden state of each RNN.
        output_size: int, number of channels produced by the output layer.
        dropout: float, dropout ratio handed to every RNN. Default is 0.
        num_layers: int, number of stacked dual-path layers. Default is 1.
        bidirectional: bool, whether the inter-segment (column) RNNs are
                       bidirectional. The intra-segment (row) RNNs are always
                       bidirectional. Default is True.
    """

    def __init__(self, rnn_type, input_size, hidden_size, output_size,
                 dropout=0, num_layers=1, bidirectional=True):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # per-layer modules of the two paths
        self.row_rnn = nn.ModuleList()
        self.col_rnn = nn.ModuleList()
        self.row_norm = nn.ModuleList()
        self.col_norm = nn.ModuleList()

        for _ in range(num_layers):
            # intra-segment RNN is always noncausal
            self.row_rnn.append(
                SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=True)
            )
            self.col_rnn.append(
                SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional)
            )
            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            # default is to use noncausal LayerNorm for inter-chunk RNN. For causal
            # setting change it to causal normalization techniques accordingly.
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))

        # output layer: PReLU followed by a 1x1 convolution to output_size channels
        self.output = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(input_size, output_size, 1),
        )

    def forward(self, input):
        """
        Apply the stacked dual-path layers and the output layer.

        input: tensor of shape (batch, N, dim1, dim2).
        returns: tensor of shape (batch, output_size, dim1, dim2).
        """
        batch_size, _, dim1, dim2 = input.shape
        feats = input

        layers = zip(self.row_rnn, self.col_rnn, self.row_norm, self.col_norm)
        for row_rnn, col_rnn, row_norm, col_norm in layers:
            # row path: sequences run along dim1
            rows = feats.permute(0, 3, 2, 1).reshape(batch_size * dim2, dim1, -1)  # B*dim2, dim1, N
            rows = row_rnn(rows)  # B*dim2, dim1, N
            rows = rows.view(batch_size, dim2, dim1, -1)
            rows = rows.permute(0, 3, 2, 1).contiguous()  # B, N, dim1, dim2
            feats = feats + row_norm(rows)

            # column path: sequences run along dim2
            cols = feats.permute(0, 2, 3, 1).reshape(batch_size * dim1, dim2, -1)  # B*dim1, dim2, N
            cols = col_rnn(cols)  # B*dim1, dim2, N
            cols = cols.view(batch_size, dim1, dim2, -1)
            cols = cols.permute(0, 3, 1, 2).contiguous()  # B, N, dim1, dim2
            feats = feats + col_norm(cols)

        return self.output(feats)
# dual-path RNN with transform-average-concatenate (TAC)
class DPRNN_TAC(nn.Module):
    """
    Deep dual-path RNN with transform-average-concatenate (TAC) applied after
    every dual-path layer.

    Every layer runs an intra-segment RNN along dim1, an inter-segment RNN along
    dim2, and then a TAC step that exchanges information across the channel
    axis. Each of the three steps is followed by a normalization and a residual
    connection.

    args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature (the N axis of the
                    (batch, ch, N, dim1, dim2) input).
        hidden_size: int, dimension of the hidden state. The TAC layers use a
                     width of hidden_size*3.
        output_size: int, number of channels produced by the output layer.
        dropout: float, dropout ratio handed to every RNN. Default is 0.
        num_layers: int, number of stacked layers. Default is 1.
        bidirectional: bool, whether the inter-segment RNNs are bidirectional.
                       The intra-segment RNNs are always bidirectional.
                       Default is True.
    """

    def __init__(self, rnn_type, input_size, hidden_size, output_size,
                 dropout=0, num_layers=1, bidirectional=True):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # DPRNN + TAC for 3D input (ch, N, T)
        self.row_rnn = nn.ModuleList()
        self.col_rnn = nn.ModuleList()
        self.ch_transform = nn.ModuleList()
        self.ch_average = nn.ModuleList()
        self.ch_concat = nn.ModuleList()
        self.row_norm = nn.ModuleList()
        self.col_norm = nn.ModuleList()
        self.ch_norm = nn.ModuleList()

        tac_size = hidden_size * 3
        for _ in range(num_layers):
            # intra-segment RNN is always noncausal
            self.row_rnn.append(
                SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=True)
            )
            self.col_rnn.append(
                SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional)
            )
            # transform: per-channel features -> TAC width
            self.ch_transform.append(
                nn.Sequential(nn.Linear(input_size, tac_size), nn.PReLU())
            )
            # average: applied to the channel-averaged features
            self.ch_average.append(
                nn.Sequential(nn.Linear(tac_size, tac_size), nn.PReLU())
            )
            # concat: [per-channel, averaged] features -> input_size
            self.ch_concat.append(
                nn.Sequential(nn.Linear(tac_size * 2, input_size), nn.PReLU())
            )
            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            # default is to use noncausal LayerNorm for inter-chunk RNN and TAC
            # modules. For causal setting change them to causal normalization
            # techniques accordingly.
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            self.ch_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))

        # output layer: PReLU followed by a 1x1 convolution to output_size channels
        self.output = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(input_size, output_size, 1),
        )

    def forward(self, input, num_mic):
        """
        Apply the stacked DPRNN + TAC layers and the output layer.

        input: tensor of shape (batch, ch, N, dim1, dim2).
        num_mic: tensor of shape (batch,). If its maximum is 0 the channel mean
                 is taken over all channels (fixed geometry array); otherwise
                 item b averages only its first num_mic[b] channels.
        returns: tensor of shape (batch*ch, output_size, dim1, dim2).
        """
        batch_size, ch, N, dim1, dim2 = input.shape
        n_pos = dim1 * dim2  # number of (dim1, dim2) positions per item
        feats = input

        layers = zip(
            self.row_rnn, self.col_rnn,
            self.ch_transform, self.ch_average, self.ch_concat,
            self.row_norm, self.col_norm, self.ch_norm,
        )
        for (row_rnn, col_rnn, transform, average, concat,
             row_norm, col_norm, ch_norm) in layers:
            # fold the channel axis into the batch axis
            feats = feats.view(batch_size * ch, N, dim1, dim2)  # B*ch, N, dim1, dim2

            # intra-segment RNN: sequences run along dim1
            rows = feats.permute(0, 3, 2, 1).reshape(batch_size * ch * dim2, dim1, -1)  # B*ch*dim2, dim1, N
            rows = row_rnn(rows)  # B*ch*dim2, dim1, N
            rows = rows.view(batch_size * ch, dim2, dim1, -1)
            rows = rows.permute(0, 3, 2, 1).contiguous()  # B*ch, N, dim1, dim2
            feats = feats + row_norm(rows)  # B*ch, N, dim1, dim2

            # inter-segment RNN: sequences run along dim2
            cols = feats.permute(0, 2, 3, 1).reshape(batch_size * ch * dim1, dim2, -1)  # B*ch*dim1, dim2, N
            cols = col_rnn(cols)  # B*ch*dim1, dim2, N
            cols = cols.view(batch_size * ch, dim1, dim2, -1)
            cols = cols.permute(0, 3, 1, 2).contiguous()  # B*ch, N, dim1, dim2
            feats = feats + col_norm(cols)  # B*ch, N, dim1, dim2

            # TAC for cross-channel communication
            tac_in = feats.view(input.shape)  # B, ch, N, dim1, dim2
            tac_in = tac_in.permute(0, 3, 4, 1, 2).reshape(-1, N)  # B*dim1*dim2*ch, N
            per_ch = transform(tac_in).view(batch_size, n_pos, ch, -1)  # B, dim1*dim2, ch, H

            # mean pooling across channels
            if num_mic.max() == 0:
                # fixed geometry array
                pooled = per_ch.mean(2)  # B, dim1*dim2, H
            else:
                # only consider valid channels
                pooled = torch.stack(
                    [per_ch[b, :, :num_mic[b]].mean(1) for b in range(batch_size)],
                    0,
                )  # B, dim1*dim2, H
            pooled = pooled.view(batch_size * n_pos, -1)  # B*dim1*dim2, H

            per_ch = per_ch.view(batch_size * n_pos, ch, -1)  # B*dim1*dim2, ch, H
            # broadcast the transformed mean back to every channel
            pooled = average(pooled).unsqueeze(1).expand_as(per_ch).contiguous()  # B*dim1*dim2, ch, H
            joined = torch.cat([per_ch, pooled], 2)  # B*dim1*dim2, ch, 2H
            tac_out = concat(joined.view(-1, joined.shape[-1]))  # B*dim1*dim2*ch, N
            tac_out = tac_out.view(batch_size, dim1, dim2, ch, -1)
            tac_out = tac_out.permute(0, 3, 4, 1, 2).contiguous()  # B, ch, N, dim1, dim2
            tac_out = ch_norm(tac_out.view(batch_size * ch, N, dim1, dim2))  # B*ch, N, dim1, dim2
            feats = feats + tac_out

        return self.output(feats)  # B*ch, output_size, dim1, dim2
# base module for deep DPRNN
class DPRNN_base(nn.Module):
    """
    Base module for deep DPRNN models.

    Holds a 1x1 bottleneck convolution, the dual-path network selected by
    `model_type`, and the helpers that cut a (B, N, T) feature sequence into
    half-overlapping segments and put it back together.

    args:
        input_dim: int, number of channels fed to the bottleneck.
        feature_dim: int, number of channels after the bottleneck; also the
                     feature size of the dual-path network.
        hidden_dim: int, hidden size of the RNNs in the dual-path network.
        output_dim: int, stored on the instance; not used inside this class.
        num_spk: int, the dual-path network outputs feature_dim*num_spk
                 channels. Default is 2.
        layer: int, number of dual-path layers. Default is 4.
        segment_size: int, stored on the instance as the segment length.
                      Default is 100.
        bidirectional: bool, forwarded to the dual-path network. Default is True.
        model_type: string, 'DPRNN' or 'DPRNN_TAC'. Default is 'DPRNN'.
        rnn_type: string, forwarded to the dual-path network. Default is 'LSTM'.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, bidirectional=True, model_type='DPRNN',
                 rnn_type='LSTM'):
        super(DPRNN_base, self).__init__()

        assert model_type in ['DPRNN', 'DPRNN_TAC'], "model_type can only be 'DPRNN' or 'DPRNN_TAC'."

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk

        self.model_type = model_type

        self.eps = 1e-8

        # bottleneck
        self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)

        # DPRNN model: the class is looked up by name in this module
        self.DPRNN = getattr(sys.modules[__name__], model_type)(
            rnn_type, self.feature_dim, self.hidden_dim, self.feature_dim*self.num_spk,
            num_layers=layer, bidirectional=bidirectional)

    def pad_segment(self, input, segment_size):
        """
        Zero-pad the time axis so it can be cut into half-overlapping segments.

        input: features of shape (B, N, T).
        segment_size: int, length of one segment.
        returns: (padded, rest). `padded` has `rest` zeros appended at the end
                 plus segment_size // 2 zeros on both sides; `rest` is the
                 number of zeros that were appended before the side padding.
        """
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # rest always falls in [1, segment_size], so the branch below is
        # always taken for a positive segment_size
        rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
        if rest > 0:
            # new_zeros inherits dtype and device from `input`, so the padding
            # is created on the same device as the data it is concatenated with
            pad = input.new_zeros(batch_size, dim, rest)
            input = torch.cat([input, pad], 2)

        pad_aux = input.new_zeros(batch_size, dim, segment_stride)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def split_feature(self, input, segment_size):
        """
        Split the features into half-overlapping chunks of `segment_size`.

        input: features of shape (B, N, T).
        segment_size: int, length of one segment.
        returns: (segments, rest). `segments` has shape
                 (B, N, segment_size, num_segments); `rest` is the padding
                 amount reported by pad_segment and is needed by merge_feature.
        """
        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # two non-overlapping segmentations, offset from each other by one stride
        segments1 = input[:, :, :-segment_stride].contiguous().view(batch_size, dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(batch_size, dim, -1, segment_size)
        # interleave the two segmentations along the segment axis
        segments = torch.cat([segments1, segments2], 3).view(batch_size, dim, -1, segment_size).transpose(2, 3)

        return segments.contiguous(), rest

    def merge_feature(self, input, rest):
        """
        Merge the split features back into a full-length sequence.

        input: features of shape (B, N, L, K), where L is the segment size and
               K the number of segments.
        rest: int, padding amount returned by split_feature.
        returns: tensor of shape (B, N, T). The two overlapping segmentations
                 are summed, not averaged, so every time step receives two
                 contributions.
        """
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size*2)  # B, N, K, L

        # undo the interleaving and drop the side padding of each segmentation
        input1 = input[:, :, :, :segment_size].contiguous().view(batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].contiguous().view(batch_size, dim, -1)[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            # remove the zeros appended by pad_segment
            output = output[:, :, :-rest]

        return output.contiguous()  # B, N, T

    def forward(self, input):
        """Placeholder that does nothing and returns None; presumably overridden by subclasses."""
        pass