# encoder.py (from a fork of ictnlp/DST)
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from src.utils import init_weights
from src.causal_conv import CausalConv
from src.causal_resblock import ResBlock
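
# Note: the streaming logic below assumes the interfaces that CausalConv and
# ResBlock expose in this repository, as inferred from their usage in this file:
#   - init_ctx_buf(batch_size, device) returns a fresh causal context buffer,
#   - calling the module as module(x, ctx_buf) returns (output, updated ctx_buf),
#   - remove_weight_norm() strips weight normalisation for inference.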
LRELU_SLOPE = 0.1


class Encoder(torch.nn.Module):
    def __init__(self, h):
        super(Encoder, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_downsamples = len(h.downsample_rates)
        self.up_buf_len = 1
        resblock = ResBlock

        # Input projection: 1 waveform channel -> downsample_initial_channel channels.
        self.conv_pre = CausalConv(1, h.downsample_initial_channel, 7, 1)

        # Strided causal convolutions: stage i doubles the channel count and
        # downsamples in time by the rate u.
        self.downs = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.downsample_rates, h.downsample_kernel_sizes)):
            self.downs.append(
                CausalConv(h.downsample_initial_channel * (2 ** i),
                           h.downsample_initial_channel * (2 ** (i + 1)),
                           k, u))

        # One residual block per (stage, kernel size); their outputs are averaged in forward().
        self.resblocks = nn.ModuleList()
        for i in range(len(self.downs)):
            ch = h.downsample_initial_channel * (2 ** i)
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, 0, d))

        # Final projection to upsample_initial_channel channels.
        self.conv_post = CausalConv(ch * 2, h.upsample_initial_channel, 7, 1)
    def init_buffers(self, batch_size, device):
        # Allocate the streaming context buffers (causal left-context state)
        # for every module, in the order forward() consumes them.
        res_buf = []
        for i in range(self.num_downsamples):
            for j in range(self.num_kernels):
                ctx_buf = self.resblocks[i * self.num_kernels + j].init_ctx_buf(batch_size, device)
                res_buf.append(ctx_buf)
        down_buf = []
        for i, (u, k) in enumerate(zip(self.h.downsample_rates, self.h.downsample_kernel_sizes)):
            ctx_buf = self.downs[i].init_ctx_buf(batch_size, device)
            down_buf.append(ctx_buf)
        pre_conv_buf = self.conv_pre.init_ctx_buf(batch_size, device)
        post_conv_buf = self.conv_post.init_ctx_buf(batch_size, device)
        buffers = pre_conv_buf, res_buf, down_buf, post_conv_buf
        return buffers
    def forward(self, x, buffers):
        pre_conv_buf, res_buf, down_buf, post_conv_buf = buffers
        # Pre-convolution (updates its causal context buffer).
        x, pre_conv_buf = self.conv_pre(x, pre_conv_buf)
        for i in range(self.num_downsamples):
            # Average the outputs of the residual blocks with different kernel sizes.
            xs = None
            for j in range(self.num_kernels):
                ctx_buf = res_buf[i * self.num_kernels + j]
                if xs is None:
                    xs, ctx_buf = self.resblocks[i * self.num_kernels + j](x, ctx_buf)
                else:
                    xs_, ctx_buf = self.resblocks[i * self.num_kernels + j](x, ctx_buf)
                    xs += xs_
                res_buf[i * self.num_kernels + j] = ctx_buf
            x = xs / self.num_kernels
            # Strided causal convolution for this downsampling stage.
            ctx_buf = down_buf[i]
            x = F.leaky_relu(x, LRELU_SLOPE)
            x, ctx_buf = self.downs[i](x, ctx_buf)
            down_buf[i] = ctx_buf
        x = F.leaky_relu(x)
        # Post-convolution and bounded output.
        x, post_conv_buf = self.conv_post(x, post_conv_buf)
        x = torch.tanh(x)
        buffers = pre_conv_buf, res_buf, down_buf, post_conv_buf
        return x, buffers
    def remove_weight_norm(self):
        for l in self.downs:
            l.remove_weight_norm()
        for l in self.resblocks:
            l.remove_weight_norm()
        self.conv_pre.remove_weight_norm()
        self.conv_post.remove_weight_norm()
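

if __name__ == "__main__":
    # Illustrative sketch only (not part of the original file): it shows how the
    # streaming buffers are meant to be used, feeding the encoder chunk by chunk.
    # The config values below are hypothetical; real values come from the
    # repository's config, and running this requires the src/ modules above.
    from types import SimpleNamespace

    h = SimpleNamespace(
        downsample_rates=[2, 2, 8, 8],
        downsample_kernel_sizes=[4, 4, 16, 16],
        downsample_initial_channel=16,
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    )

    encoder = Encoder(h)
    device = torch.device("cpu")
    buffers = encoder.init_buffers(batch_size=1, device=device)

    # Process a waveform in fixed-size chunks, carrying the context buffers
    # across calls so the causal convolutions see one continuous signal.
    chunk_size = 2048
    wav = torch.randn(1, 1, 4 * chunk_size)
    latents = []
    for start in range(0, wav.size(-1), chunk_size):
        chunk = wav[:, :, start:start + chunk_size]
        z, buffers = encoder(chunk, buffers)
        latents.append(z)
    z_all = torch.cat(latents, dim=-1)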