|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Created on 2018/12/18 |
| 4 | +# Author: Kaituo XU |
| 5 | + |
| 6 | +import argparse |
| 7 | +import os |
| 8 | + |
| 9 | +import librosa |
| 10 | +from mir_eval.separation import bss_eval_sources |
| 11 | +import numpy as np |
| 12 | +import torch |
| 13 | + |
| 14 | +from data import AudioDataLoader, AudioDataset |
| 15 | +from pit_criterion import cal_loss |
| 16 | +from tasnet import TasNet |
| 17 | + |
| 18 | + |
# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser(
    'Evaluate separation performance using TasNet')
parser.add_argument('--model_path', required=True, type=str,
                    help='Path to model file created by training')
parser.add_argument('--data_dir', required=True, type=str,
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--use_cuda', default=0, type=int,
                    help='Whether use GPU')
parser.add_argument('--sample_rate', type=int, default=8000,
                    help='Sample rate')
parser.add_argument('--batch_size', type=int, default=1,
                    help='Batch size')
| 31 | + |
def evaluate(args):
    """Evaluate a trained TasNet model on a separation test set.

    Loads the model from ``args.model_path``, runs it over the data in
    ``args.data_dir``, and prints per-utterance and average SI-SNR and SDR
    improvements relative to the unprocessed mixture.

    Args:
        args: argparse.Namespace with model_path, data_dir, use_cuda,
            sample_rate and batch_size (see the module-level parser).
    """
    total_sisnr = 0
    total_sdr = 0
    total_cnt = 0
    # Load model
    model = TasNet.load_model(args.model_path)
    print(model)
    model.eval()  # disable dropout / batch-norm updates for evaluation
    if args.use_cuda:
        model.cuda()

    # Load data
    dataset = AudioDataset(args.data_dir, args.batch_size,
                           sample_rate=args.sample_rate, L=model.L)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture, mixture_lengths)  # [B, C, K, L]
            # cal_loss performs the PIT permutation search; we keep its
            # returned estimate_source for metric computation below.
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad_and_flat(padded_mixture, mixture_lengths)
            source = remove_pad_and_flat(padded_source, mixture_lengths)
            estimate_source = remove_pad_and_flat(estimate_source, mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                # The raw mixture is the "do nothing" baseline for each source.
                src_anchor = np.stack([mix, mix], axis=0)
                sisnr1 = get_SISNR(src_ref[0], src_est[0])
                sisnr2 = get_SISNR(src_ref[1], src_est[1])
                sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
                sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
                print("sisnr1: {0:.2f}, sisnr2: {1:.2f}".format(sisnr1, sisnr2))
                # BUG FIX: source 2's SDR improvement must be measured against
                # its own mixture baseline sdr0[1], not source 1's sdr0[0].
                print("sdr1: {0:.2f}, sdr2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[1]))

                total_sisnr += sisnr1 + sisnr2
                total_sdr += (sdr[0]-sdr0[0]) + (sdr[1]-sdr0[1])
                total_cnt += 2
    # Guard against an empty dataset so we don't divide by zero.
    if total_cnt > 0:
        print("Average sisnr improvement: {0:.2f}".format(total_sisnr / total_cnt))
        print("Average sdr improvement: {0:.2f}".format(total_sdr / total_cnt))
| 83 | + |
| 84 | + |
def remove_pad_and_flat(inputs, inputs_lengths):
    """Strip padding segments and flatten the segment axis per utterance.

    Args:
        inputs: torch.Tensor, [B, C, K, L] or [B, K, L]
        inputs_lengths: torch.Tensor, [B] — number of valid K-segments
    Returns:
        list of B numpy arrays; [C, T] for 4-D input, [T] for 3-D input,
        where T varies per utterance.
    """
    ndim = inputs.dim()
    num_channels = inputs.size(1) if ndim == 4 else None
    flattened = []
    for utt, n_valid in zip(inputs, inputs_lengths):
        if ndim == 4:  # [B, C, K, L] -> [C, n_valid * L]
            flattened.append(utt[:, :n_valid].view(num_channels, -1).cpu().numpy())
        elif ndim == 3:  # [B, K, L] -> [n_valid * L]
            flattened.append(utt[:n_valid].view(-1).cpu().numpy())
    return flattened
| 103 | + |
| 104 | + |
def get_SISNR(ref_sig, out_sig, eps=1e-8):
    """Return the scale-invariant SNR (in dB) of an estimate vs. a reference.

    Both signals are zero-meaned, the estimate is projected onto the
    reference, and the energy ratio of the projection to the residual is
    converted to decibels. ``eps`` avoids division by zero / log of zero.
    """
    assert len(ref_sig) == len(out_sig)
    ref = ref_sig - np.mean(ref_sig)
    est = out_sig - np.mean(out_sig)
    # Optimal scaling factor of the reference toward the estimate.
    scale = np.sum(ref * est) / (np.sum(ref ** 2) + eps)
    target = scale * ref
    residual = est - target
    power_ratio = np.sum(target ** 2) / (np.sum(residual ** 2) + eps)
    return 10 * np.log10(power_ratio + eps)
| 115 | + |
| 116 | + |
if __name__ == '__main__':
    # Parse the CLI, echo the configuration, then run the evaluation.
    cli_args = parser.parse_args()
    print(cli_args)
    evaluate(cli_args)
0 commit comments