Skip to content

Commit 1e6885a

Browse files
committed
impl evaluate that compute SI-SNRi, SDRi
1 parent 771eb14 commit 1e6885a

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
librosa
2+
mir_eval
23
visdom

src/evaluate.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#!/usr/bin/env python
2+
3+
# Created on 2018/12/18
4+
# Author: Kaituo XU
5+
6+
import argparse
7+
import os
8+
9+
import librosa
10+
from mir_eval.separation import bss_eval_sources
11+
import numpy as np
12+
import torch
13+
14+
from data import AudioDataLoader, AudioDataset
15+
from pit_criterion import cal_loss
16+
from tasnet import TasNet
17+
18+
19+
parser = argparse.ArgumentParser('Evaluate separation performance using TasNet')
20+
parser.add_argument('--model_path', type=str, required=True,
21+
help='Path to model file created by training')
22+
parser.add_argument('--data_dir', type=str, required=True,
23+
help='directory including mix.json, s1.json and s2.json')
24+
parser.add_argument('--use_cuda', type=int, default=0,
25+
help='Whether use GPU')
26+
parser.add_argument('--sample_rate', default=8000, type=int,
27+
help='Sample rate')
28+
parser.add_argument('--batch_size', default=1, type=int,
29+
help='Batch size')
30+
31+
32+
def evaluate(args):
33+
total_sisnr = 0
34+
total_sdr = 0
35+
total_cnt = 0
36+
# Load model
37+
model = TasNet.load_model(args.model_path)
38+
print(model)
39+
model.eval()
40+
if args.use_cuda:
41+
model.cuda()
42+
43+
# Load data
44+
dataset = AudioDataset(args.data_dir, args.batch_size,
45+
sample_rate=args.sample_rate, L=model.L)
46+
data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)
47+
48+
with torch.no_grad():
49+
for i, (data) in enumerate(data_loader):
50+
# Get batch data
51+
padded_mixture, mixture_lengths, padded_source = data
52+
if args.use_cuda:
53+
padded_mixture = padded_mixture.cuda()
54+
mixture_lengths = mixture_lengths.cuda()
55+
padded_source = padded_source.cuda()
56+
# Forward
57+
estimate_source = model(padded_mixture, mixture_lengths) # [B, C, K, L]
58+
loss, max_snr, estimate_source, reorder_estimate_source = \
59+
cal_loss(padded_source, estimate_source, mixture_lengths)
60+
# Remove padding and flat
61+
mixture = remove_pad_and_flat(padded_mixture, mixture_lengths)
62+
source = remove_pad_and_flat(padded_source, mixture_lengths)
63+
estimate_source = remove_pad_and_flat(estimate_source, mixture_lengths)
64+
# for each utterance
65+
for mix, src_ref, src_est in zip(mixture, source, estimate_source):
66+
# src_ref = np.stack([s1, s2], axis=0)
67+
# src_est = np.stack([recon_s1_sig, recon_s2_sig], axis=0)
68+
src_anchor = np.stack([mix, mix], axis=0)
69+
sisnr1 = get_SISNR(src_ref[0], src_est[0])
70+
sisnr2 = get_SISNR(src_ref[1], src_est[1])
71+
sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
72+
sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
73+
# sisnr1 = get_SISNR(s1, recon_s1_sig)
74+
# sisnr2 = get_SISNR(s2, recon_s2_sig)
75+
print("sisnr1: {0:.2f}, sisnr2: {1:.2f}".format(sisnr1, sisnr2))
76+
print("sdr1: {0:.2f}, sdr2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[0]))
77+
78+
total_sisnr += sisnr1 + sisnr2
79+
total_sdr += (sdr[0]-sdr0[0]) + (sdr[1]-sdr0[0])
80+
total_cnt += 2
81+
print("Average sisnr improvement: {0:.2f}".format(total_sisnr / total_cnt))
82+
print("Average sdr improvement: {0:.2f}".format(total_sdr / total_cnt))
83+
84+
85+
def remove_pad_and_flat(inputs, inputs_lengths):
86+
"""
87+
Args:
88+
inputs: torch.Tensor, [B, C, K, L] or [B, K, L]
89+
inputs_lengths: torch.Tensor, [B]
90+
Returns:
91+
results: a list containing B items, each item is [C, T], T varies
92+
"""
93+
results = []
94+
dim = inputs.dim()
95+
if dim == 4:
96+
C = inputs.size(1)
97+
for input, length in zip(inputs, inputs_lengths):
98+
if dim == 4: # [B, C, K, L]
99+
results.append(input[:,:length].view(C, -1).cpu().numpy())
100+
elif dim == 3: # [B, K, L]
101+
results.append(input[:length].view(-1).cpu().numpy())
102+
return results
103+
104+
105+
def get_SISNR(ref_sig, out_sig, eps=1e-8):
106+
assert len(ref_sig) == len(out_sig)
107+
ref_sig = ref_sig - np.mean(ref_sig)
108+
out_sig = out_sig - np.mean(out_sig)
109+
ref_energy = np.sum(ref_sig ** 2) + eps
110+
proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
111+
noise = out_sig - proj
112+
ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
113+
sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
114+
return sisnr
115+
116+
117+
if __name__ == '__main__':
118+
args = parser.parse_args()
119+
print(args)
120+
evaluate(args)

0 commit comments

Comments
 (0)