
Commit

Remove CLAP
OedoSoldier committed Dec 16, 2023
1 parent dcb808f commit 62fd59b
Showing 5 changed files with 33 additions and 219 deletions.
15 changes: 1 addition & 14 deletions data_utils.py
@@ -44,8 +44,6 @@ def __init__(self, audiopaths_sid_text, hparams):
self.min_text_len = getattr(hparams, "min_text_len", 1)
self.max_text_len = getattr(hparams, "max_text_len", 384)

self.variance = 0.015

random.seed(1234)
random.shuffle(self.audiopaths_sid_text)
self._filter()
@@ -96,13 +94,7 @@ def get_audio_text_speaker_pair(self, audiopath_sid_text):
spec, wav = self.get_audio(audiopath)
sid = torch.LongTensor([int(self.spk_map[sid])])

emo = torch.squeeze(
torch.load(audiopath.replace(".wav", ".emo.pt"), map_location="cpu"),
dim=1,
)

emo += torch.randn_like(emo) * self.variance
return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)

def get_audio(self, filename):
audio, sampling_rate = load_wav_to_torch(filename)
@@ -223,7 +215,6 @@ def __call__(self, batch):
bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
emo = torch.FloatTensor(len(batch), 512)

spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
@@ -235,7 +226,6 @@ def __call__(self, batch):
bert_padded.zero_()
ja_bert_padded.zero_()
en_bert_padded.zero_()
emo.zero_()

for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
@@ -269,8 +259,6 @@ def __call__(self, batch):
en_bert = row[8]
en_bert_padded[i, :, : en_bert.size(1)] = en_bert

emo[i, :] = row[9]

return (
text_padded,
text_lengths,
@@ -284,7 +272,6 @@ def __call__(self, batch):
bert_padded,
ja_bert_padded,
en_bert_padded,
emo,
)


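With the emotion embedding removed, a dataset item now carries nine fields instead of ten. A minimal sketch of consuming the new layout (the tuple order follows the updated return of get_audio_text_speaker_pair above; `dataset` is assumed to be an instance of the loader class defined in data_utils.py):

# Sketch only, assuming `dataset` wraps the loader class from data_utils.py.
# The trailing `emo` tensor is gone, so downstream code that indexed row[9]
# (as the old collate step did) must be updated as well.
phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert = dataset[0]
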
71 changes: 19 additions & 52 deletions infer.py
@@ -64,17 +64,6 @@
}


# def get_emo_(reference_audio, emotion, sid):
# emo = (
# torch.from_numpy(get_emo(reference_audio))
# if reference_audio and emotion == -1
# else torch.FloatTensor(
# np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
# )
# )
# return emo


def get_net_g(model_path: str, version: str, device: str, hps):
if version != latest_version:
net_g = SynthesizerTrnMap[version](
@@ -141,7 +130,6 @@ def get_text(text, language_str, hps, device):

def infer(
text,
emotion,
sdp_ratio,
noise_scale,
noise_scale_w,
@@ -151,7 +139,6 @@
hps,
net_g,
device,
reference_audio=None,
skip_start=False,
skip_end=False,
):
@@ -180,23 +167,23 @@
version = hps.version if hasattr(hps, "version") else latest_version
# Not the latest version: pick the matching infer implementation by version number
if version != latest_version:
if version in inferMap_V3.keys():
return inferMap_V3[version](
text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
language,
hps,
net_g,
device,
reference_audio,
emotion,
skip_start,
skip_end,
)
# if version in inferMap_V3.keys():
# return inferMap_V3[version](
# text,
# sdp_ratio,
# noise_scale,
# noise_scale_w,
# length_scale,
# sid,
# language,
# hps,
# net_g,
# device,
# reference_audio,
# emotion,
# skip_start,
# skip_end,
# )
if version in inferMap_V2.keys():
return inferMap_V2[version](
text,
@@ -222,14 +209,6 @@ def infer(
net_g,
device,
)
# Inference for the current version is implemented here
# emo = get_emo_(reference_audio, emotion, sid)
if isinstance(reference_audio, np.ndarray):
emo = get_clap_audio_feature(reference_audio, device)
else:
emo = get_clap_text_feature(emotion, device)
emo = torch.squeeze(emo, dim=1)

bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
text, language, hps, device
)
@@ -255,7 +234,6 @@ def infer(
ja_bert = ja_bert.to(device).unsqueeze(0)
en_bert = en_bert.to(device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
emo = emo.to(device).unsqueeze(0)
del phones
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
audio = (
@@ -268,7 +246,6 @@ def infer(
bert,
ja_bert,
en_bert,
emo,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
@@ -278,7 +255,7 @@ def infer(
.float()
.numpy()
)
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert
if torch.cuda.is_available():
torch.cuda.empty_cache()
return audio
@@ -295,18 +272,10 @@ def infer_multilang(
hps,
net_g,
device,
reference_audio=None,
emotion=None,
skip_start=False,
skip_end=False,
):
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
# emo = get_emo_(reference_audio, emotion, sid)
if isinstance(reference_audio, np.ndarray):
emo = get_clap_audio_feature(reference_audio, device)
else:
emo = get_clap_text_feature(emotion, device)
emo = torch.squeeze(emo, dim=1)
for idx, (txt, lang) in enumerate(zip(text, language)):
skip_start = (idx != 0) or (skip_start and idx == 0)
skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
@@ -351,7 +320,6 @@ def infer_multilang(
bert = bert.to(device).unsqueeze(0)
ja_bert = ja_bert.to(device).unsqueeze(0)
en_bert = en_bert.to(device).unsqueeze(0)
emo = emo.to(device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
del phones
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
@@ -365,7 +333,6 @@ def infer_multilang(
bert,
ja_bert,
en_bert,
emo,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
@@ -375,7 +342,7 @@ def infer_multilang(
.float()
.numpy()
)
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert
if torch.cuda.is_available():
torch.cuda.empty_cache()
return audio
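Because `reference_audio` and `emotion` are no longer parameters of `infer`, a call against the current version reduces to the remaining arguments. A hedged sketch (keyword names follow the signature and the inferMap calls visible in the diff above; `hps` and `net_g` are assumed to have been loaded beforehand, e.g. via get_net_g):

# Sketch only: calling infer() after this commit.
# Passing emotion= or reference_audio= now raises a TypeError.
audio = infer(
    text="Hello there.",
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="speaker_0",   # assumed: a key of hps.data.spk2id
    language="EN",     # assumed language code
    hps=hps,
    net_g=net_g,
    device="cuda",
)
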
51 changes: 6 additions & 45 deletions models.py
@@ -343,7 +343,6 @@ def __init__(
n_layers,
kernel_size,
p_dropout,
n_speakers,
gin_channels=0,
):
super().__init__()
@@ -365,31 +364,6 @@ def __init__(
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
# self.emo_proj = nn.Linear(512, hidden_channels)
self.in_feature_net = nn.Sequential(
# input is assumed to an already normalized embedding
nn.Linear(512, 1028, bias=False),
nn.GELU(),
nn.LayerNorm(1028),
*[Block(1028, 512) for _ in range(1)],
nn.Linear(1028, 512, bias=False),
# normalize before passing to VQ?
# nn.GELU(),
# nn.LayerNorm(512),
)
self.emo_vq = VectorQuantize(
dim=512,
codebook_size=64,
codebook_dim=32,
commitment_weight=0.1,
decay=0.85,
heads=32,
kmeans_iters=20,
separate_codebook_per_head=True,
stochastic_sample_codes=True,
threshold_ema_dead_code=2,
)
self.out_feature_net = nn.Linear(512, hidden_channels)

self.encoder = attentions.Encoder(
hidden_channels,
@@ -402,26 +376,17 @@ def __init__(
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(
self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
):
sid = sid.cpu()
def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
bert_emb = self.bert_proj(bert).transpose(1, 2)
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
emo_emb = self.in_feature_net(emo)
emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
loss_commit = loss_commit.mean()
emo_emb = self.out_feature_net(emo_emb)
# emo_emb = self.emo_proj(emo.unsqueeze(1))
x = (
self.emb(x)
+ self.tone_emb(tone)
+ self.language_emb(language)
+ bert_emb
+ ja_bert_emb
+ en_bert_emb
+ emo_emb
) * math.sqrt(
self.hidden_channels
) # [b, t, h]
@@ -434,7 +399,7 @@ def forward(
stats = self.proj(x) * x_mask

m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask, loss_commit
return x, m, logs, x_mask


class ResidualCouplingBlock(nn.Module):
@@ -916,7 +881,6 @@ def __init__(
n_layers,
kernel_size,
p_dropout,
self.n_speakers,
gin_channels=self.enc_gin_channels,
)
self.dec = Generator(
@@ -984,14 +948,13 @@ def forward(
bert,
ja_bert,
en_bert,
emo=None,
):
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
x, m_p, logs_p, x_mask = self.enc_p(
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
@@ -1060,7 +1023,6 @@ def forward(
(z, z_p, m_p, logs_p, m_q, logs_q),
(x, logw, logw_, logw_sdp),
g,
loss_commit,
)

def infer(
@@ -1073,7 +1035,6 @@ def infer(
bert,
ja_bert,
en_bert,
emo=None,
noise_scale=0.667,
length_scale=1,
noise_scale_w=0.8,
@@ -1087,8 +1048,8 @@ def infer(
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
x, m_p, logs_p, x_mask, _ = self.enc_p(
x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
x, m_p, logs_p, x_mask = self.enc_p(
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
)
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
sdp_ratio
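On the model side, TextEncoder.forward and SynthesizerTrn.infer both lose their emotion inputs, so the inference call from infer.py shrinks to the arguments that remain. A hedged sketch of the updated call (tensor preparation is assumed to match infer.py above, and the positional argument order is assumed to follow the upstream Bert-VITS2 convention):

# Sketch only: SynthesizerTrn.infer after this commit; the emo= keyword
# accepted by the previous version is no longer part of the signature.
o = net_g.infer(
    x_tst,
    x_tst_lengths,
    speakers,
    tones,
    lang_ids,
    bert,
    ja_bert,
    en_bert,
    sdp_ratio=sdp_ratio,
    noise_scale=noise_scale,
    noise_scale_w=noise_scale_w,
    length_scale=length_scale,
)[0]
audio = o[0, 0].data.cpu().float().numpy()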
