From e2a47ddb7092d2959298738ff45dc427f0466b29 Mon Sep 17 00:00:00 2001
From: kamo-naoyuki
Date: Mon, 11 Mar 2019 18:09:20 +0900
Subject: [PATCH] Implement: utils/eval-source-separation.py

---
 utils/eval-source-separation.py | 385 ++++++++++++++++++++++++++++++++
 utils/eval_source_separation.sh |  72 ++++++
 2 files changed, 457 insertions(+)
 create mode 100755 utils/eval-source-separation.py
 create mode 100755 utils/eval_source_separation.sh

diff --git a/utils/eval-source-separation.py b/utils/eval-source-separation.py
new file mode 100755
index 00000000000..cb39c7489a9
--- /dev/null
+++ b/utils/eval-source-separation.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import argparse
+import shutil
+import tempfile
+from collections import OrderedDict
+from distutils.util import strtobool
+import itertools
+import os
+import subprocess
+import sys
+import logging
+
+import museval
+import numpy as np
+import soundfile
+from pystoi.stoi import stoi
+
+from espnet.utils.cli_utils import get_commandline_args
+
+
+PY2 = sys.version_info[0] == 2
+
+
+def eval_STOI(ref, y, fs, extended=False, compute_permutation=True):
+    """Calculate STOI
+
+    Reference:
+        A short-time objective intelligibility measure
+        for time-frequency weighted noisy speech
+        https://ieeexplore.ieee.org/document/5495701
+
+    Note(kamo):
+        STOI is defined on the signal at 10kHz
+        and the input at the other sampling rate will be resampled.
+        Thus, the result differs depending on the implementation of resampling.
+        Especially, pystoi cannot reproduce matlab's resampling now.
+
+    :param ref (np.ndarray): Reference (Nsrc, Nframe, Nmic)
+    :param y (np.ndarray): Enhanced (Nsrc, Nframe, Nmic)
+    :param fs (int): Sample frequency
+    :param extended (bool): stoi or estoi
+    :param compute_permutation (bool):
+    :return: value, perm
+    :rtype: Tuple[Tuple[float, ...], Tuple[int, ...]]
+    """
+    if ref.shape != y.shape:
+        raise ValueError('ref and y should have the same shape: {} != {}'
+                         .format(ref.shape, y.shape))
+    if ref.ndim != 3:
+        raise ValueError('Input must have 3 dims: {}'.format(ref.ndim))
+    n_src = ref.shape[0]
+    n_mic = ref.shape[2]
+
+    if compute_permutation:
+        index_list = list(itertools.permutations(range(n_src)))
+    else:
+        index_list = [list(range(n_src))]
+
+    values = [[sum(stoi(ref[i, :, ch], y[j, :, ch], fs, extended)
+                   for ch in range(n_mic)) / n_mic
+               for i, j in enumerate(indices)]
+              for indices in index_list]
+
+    # Keep the permutation whose scores summed over sources are the largest
+    best_pairs = sorted([(v, i) for v, i in zip(values, index_list)],
+                        key=lambda x: sum(x[0]))[-1]
+    value, perm = best_pairs
+    return tuple(value), tuple(perm)
+
+
+class TemporaryDirectory(object):
+    """Ported from python3 tempfile.TemporaryDirectory"""
+    def __init__(self, suffix=None, prefix=None, dir=None):
+        self.name = tempfile.mkdtemp(suffix, prefix, dir)
+
+    def __repr__(self):
+        return "<{} {!r}>".format(self.__class__.__name__, self.name)
+
+    def __enter__(self):
+        return self.name
+
+    def __exit__(self, exc, value, tb):
+        self.cleanup()
+
+    def cleanup(self):
+        shutil.rmtree(self.name)
+
+
+def eval_PESQ(ref, enh, fs, compute_permutation):
+    """Evaluate PESQ
+
+    PESQ program can be downloaded from here:
+        https://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-P.862-200102-I!!SOFT-ZST-E&type=items
+
+    Reference:
+        Perceptual evaluation of speech quality (PESQ)-a new method
+        for speech quality assessment of telephone networks and codecs
+        https://ieeexplore.ieee.org/document/941023
+
+    :param ref (np.ndarray): Reference (Nsrc, Nframe, Nmic)
+    :param enh (np.ndarray): Enhanced (Nsrc, Nframe, Nmic)
+    :param fs (int): Sample frequency
+    :param compute_permutation (bool):
+    :return: value, perm
+    :rtype: Tuple[Tuple[float, ...], Tuple[int, ...]]
+    """
+    if PY2:
+        p = subprocess.Popen(['which', 'PESQ'], stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE)
+        _, _ = p.communicate()
+        if p.returncode != 0:
+            raise RuntimeError('PESQ: command not found: Please install')
+    else:
+        import shutil
+        if shutil.which('PESQ') is None:
+            raise RuntimeError('PESQ: command not found: Please install')
+    if fs not in (8000, 16000):
+        raise ValueError('Sample frequency must be 8000 or 16000: {}'
+                         .format(fs))
+    if ref.shape != enh.shape:
+        raise ValueError('ref and enh should have the same shape: {} != {}'
+                         .format(ref.shape, enh.shape))
+    if ref.ndim != 3:
+        raise ValueError('Input must have 3 dims: {}'.format(ref.ndim))
+
+    n_src = ref.shape[0]
+    n_mic = ref.shape[2]
+    with TemporaryDirectory() as d:
+        # TODO(kamo): Should we use python-binding for PESQ?
+        #   Such as https://github.com/vBaiCai/python-pesq
+        #   I'm not sure this approach is permitted as the licence agreement.
+
+        # Dumping wav files temporary
+        ref_files = []
+        enh_files = []
+        for isrc in range(n_src):
+            refs = []  # [Nsrc, Nmic]
+            enhs = []  # [Nsrc, Nmic]
+            for imic in range(n_mic):
+                wv = str(os.path.join(d, 'ref.{}.{}.wav'.format(isrc, imic)))
+                soundfile.write(wv, ref[isrc, :, imic].astype(np.int16), fs)
+                refs.append(wv)
+
+                wv = str(os.path.join(d, 'enh.{}.{}.wav'.format(isrc, imic)))
+                soundfile.write(wv, enh[isrc, :, imic].astype(np.int16), fs)
+                enhs.append(wv)
+            ref_files.append(refs)
+            enh_files.append(enhs)
+
+        if compute_permutation:
+            index_list = list(itertools.permutations(range(n_src)))
+        else:
+            index_list = [list(range(n_src))]
+
+        values = []
+        for indices in index_list:
+            values2 = []
+            for i, j in enumerate(indices):
+                lis = []
+                for imic in range(n_mic):
+                    commands = ['PESQ', '+{}'.format(fs),
+                                ref_files[i][imic], enh_files[j][imic]]
+                    # Popen is not a context manager on Python 2; avoid `with`
+                    p = subprocess.Popen(commands, stdout=subprocess.PIPE)
+                    stdout, _ = p.communicate()
+
+                    # Get the PESQ value from the stdout
+                    last_line = stdout.decode().rstrip().split('\n')[-1]
+                    if 'Prediction : PESQ_MOS = ' in last_line:
+                        value = last_line.replace(
+                            'Prediction : PESQ_MOS = ', '')
+                        lis.append(float(value))
+                    else:
+                        raise RuntimeError(
+                            'Failed: {}\n{}'.format(' '.join(commands),
+                                                    stdout.decode()))
+                # Averaging over n_mic
+                values2.append(sum(lis) / len(lis))
+            values.append(values2)
+    # Keep the permutation whose scores summed over sources are the largest
+    best_pairs = sorted([(v, i) for v, i in zip(values, index_list)],
+                        key=lambda x: sum(x[0]))[-1]
+    value, perm = best_pairs
+    return tuple(value), tuple(perm)
+
+
+def main():
+    """Parse the command line and write one score file per metric."""
+    parser = argparse.ArgumentParser(
+        description='Evaluate enhanced speech. '
+                    'e.g. {c} --ref ref.scp --enh enh.scp --outdir outputdir '
+                    'or {c} --ref ref.scp ref2.scp --enh enh.scp enh2.scp '
+                    '--outdir outputdir'
+        .format(c=sys.argv[0]),
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--verbose', '-V', default=0, type=int,
+                        help='Verbose option')
+    parser.add_argument('--ref', dest='reffiles', nargs='+', type=str,
+                        required=True,
+                        help='WAV file lists for reference')
+    parser.add_argument('--enh', dest='enhfiles', nargs='+', type=str,
+                        required=True,
+                        help='WAV files lists for enhanced')
+    parser.add_argument('--outdir', type=str,
+                        required=True)
+    parser.add_argument('--keylist', type=str,
+                        help='Specify the target samples. By default, '
+                             'using all keys in the first reference file')
+    parser.add_argument('--evaltypes', type=str, nargs='+',
+                        choices=['SDR', 'STOI', 'ESTOI', 'PESQ'],
+                        default=['SDR', 'STOI', 'ESTOI', 'PESQ'])
+    parser.add_argument('--permutation', type=strtobool, default=True,
+                        help='Compute all permutations or '
+                             'use the pair of input order')
+    parser.add_argument('--source-image', type=strtobool, default=True,
+                        help='Use this option in the case that '
+                             'the reference is source image. '
+                             'For more detail, see museval source codes.')
+    args = parser.parse_args()
+
+    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+    if args.verbose > 0:
+        logging.basicConfig(level=logging.INFO, format=logfmt)
+    else:
+        logging.basicConfig(level=logging.WARN, format=logfmt)
+    logging.info(get_commandline_args())
+    if len(args.reffiles) != len(args.enhfiles):
+        raise RuntimeError(
+            'The number of ref files are different '
+            'from the enh files: {} != {}'.format(len(args.reffiles),
+                                                  len(args.enhfiles)))
+    if len(args.enhfiles) == 1:
+        args.permutation = False
+
+    # Read text files and created a mapping of key2filepath
+    reffiles_dict = OrderedDict()  # Dict[str, Dict[str, str]]
+    for ref in args.reffiles:
+        d = OrderedDict()
+        with open(ref, 'r') as f:
+            for line in f:
+                key, path = line.split(None, 1)
+                d[key] = path.rstrip()
+        reffiles_dict[ref] = d
+
+    enhfiles_dict = OrderedDict()  # Dict[str, Dict[str, str]]
+    for enh in args.enhfiles:
+        d = OrderedDict()
+        with open(enh, 'r') as f:
+            for line in f:
+                key, path = line.split(None, 1)
+                d[key] = path.rstrip()
+        enhfiles_dict[enh] = d
+
+    if args.keylist is not None:
+        with open(args.keylist, 'r') as f:
+            keylist = [line.rstrip().split()[0] for line in f]
+    else:
+        keylist = list(reffiles_dict.values())[0]
+
+    if len(keylist) == 0:
+        raise RuntimeError('No keys are found')
+
+    if not os.path.exists(args.outdir):
+        os.makedirs(args.outdir)
+
+    evaltypes = []
+    for evaltype in args.evaltypes:
+        if evaltype == 'SDR':
+            evaltypes += ['SDR', 'ISR', 'SIR', 'SAR']
+        else:
+            evaltypes.append(evaltype)
+
+    # Open files in write mode
+    writers = {k: open(os.path.join(args.outdir, k), 'w') for k in evaltypes}
+
+    for key in keylist:
+        # 1. Load ref files
+        rate_prev = None
+
+        ref_signals = []
+        for listname, d in reffiles_dict.items():
+            if key not in d:
+                raise RuntimeError('{} doesn\'t exist in {}'
+                                   .format(key, listname))
+            filepath = d[key]
+            signal, rate = soundfile.read(filepath, dtype=np.int16)
+            if signal.ndim == 1:
+                # (Nframe) -> (Nframe, 1)
+                signal = signal[:, None]
+            ref_signals.append(signal)
+            if rate_prev is not None and rate != rate_prev:
+                raise RuntimeError('Sampling rates mismatch')
+            rate_prev = rate
+
+        # 2. Load enh files
+        enh_signals = []
+        for listname, d in enhfiles_dict.items():
+            if key not in d:
+                raise RuntimeError('{} doesn\'t exist in {}'
+                                   .format(key, listname))
+            filepath = d[key]
+            signal, rate = soundfile.read(filepath, dtype=np.int16)
+            if signal.ndim == 1:
+                # (Nframe) -> (Nframe, 1)
+                signal = signal[:, None]
+            enh_signals.append(signal)
+            if rate_prev is not None and rate != rate_prev:
+                raise RuntimeError('Sampling rates mismatch')
+            rate_prev = rate
+
+        for signal in ref_signals + enh_signals:
+            if signal.shape[1] != ref_signals[0].shape[1]:
+                raise RuntimeError('The number of channels mismatch')
+
+        # 3. Zero padding to adjust the length to the maximum length in inputs
+        ml = max(len(s) for s in ref_signals + enh_signals)
+        ref_signals = [np.pad(s, [(0, ml - len(s)), (0, 0)], mode='constant')
+                       if len(s) < ml else s for s in ref_signals]
+
+        enh_signals = [np.pad(s, [(0, ml - len(s)), (0, 0)], mode='constant')
+                       if len(s) < ml else s for s in enh_signals]
+
+        # ref_signals, enh_signals: (Nsrc, Nframe, Nmic)
+        ref_signals = np.stack(ref_signals, axis=0)
+        enh_signals = np.stack(enh_signals, axis=0)
+
+        # 4. Evaluates
+        for evaltype in args.evaltypes:
+            if evaltype == 'SDR':
+                if args.source_image:
+                    (sdr, isr, sir, sar, perm) = \
+                        museval.metrics.bss_eval_images(
+                            ref_signals, enh_signals,
+                            compute_permutation=args.permutation)
+                else:
+                    (sdr, sir, sar, perm) = \
+                        museval.metrics.bss_eval_sources(
+                            ref_signals, enh_signals,
+                            compute_permutation=args.permutation)
+                    isr = np.array([[0] for _ in range(len(sdr))])
+
+                # sdr: (Nsrc, Nframe)
+                writers['SDR'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, sdr[:, 0]))))
+                writers['ISR'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, isr[:, 0]))))
+                writers['SIR'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, sir[:, 0]))))
+                writers['SAR'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, sar[:, 0]))))
+
+            elif evaltype == 'STOI':
+                stoi, perm = eval_STOI(ref_signals, enh_signals, rate,
+                                       extended=False,
+                                       compute_permutation=args.permutation)
+                writers['STOI'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, stoi))))
+
+            elif evaltype == 'ESTOI':
+                estoi, perm = eval_STOI(ref_signals, enh_signals, rate,
+                                        extended=True,
+                                        compute_permutation=args.permutation)
+                writers['ESTOI'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, estoi))))
+
+            elif evaltype == 'PESQ':
+                pesq, perm = eval_PESQ(ref_signals, enh_signals, rate,
+                                       compute_permutation=args.permutation)
+                writers['PESQ'].write(
+                    '{} {}\n'.format(key, ' '.join(map(str, pesq))))
+            else:
+                # Cannot reach
+                raise RuntimeError
+
+    for writer in writers.values():
+        writer.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils/eval_source_separation.sh b/utils/eval_source_separation.sh
new file mode 100755
index 00000000000..319a063ba7c
--- /dev/null
+++ b/utils/eval_source_separation.sh
@@ -0,0 +1,72 @@
+#!/usr/bin/env bash
+
+echo "$0 $*" >&2 # Print the command line for logging
+
+nj=10
+cmd=run.pl
+evaltypes="SDR STOI ESTOI"
+permutation=true
+source_image=true
+
+. ./path.sh
+. utils/parse_options.sh
+
+if [ $# != 3 ]; then
+    cat << EOF 1>&2
+Usage: $0 reffiles enhfiles outdir
+    e.g. $0 reference.scp enhanced.scp outdir
+
+And also supporting multiple sources:
+    e.g. $0 "ref1.scp,ref2.scp" "enh1.scp,enh2.scp" outdir
+
+Options:
+  --nj # number of parallel jobs
+  --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs.
+EOF
+    exit 1;
+fi
+
+set -euo pipefail
+
+reffiles=( $(echo $1 | tr , " ") )
+enhfiles=( $(echo $2 | tr , " ") )
+dir=$3
+logdir=${dir}/log
+mkdir -p ${logdir}
+
+split_scps=""
+for n in $(seq ${nj}); do
+    split_scps="${split_scps} ${logdir}/key.${n}.scp"
+done
+
+# Split the first reference
+utils/split_scp.pl ${reffiles[0]} ${split_scps} || exit 1;
+
+${cmd} JOB=1:${nj} ${logdir}/eval-enhanced-speech.JOB.log \
+    eval-source-separation.py \
+    --ref ${reffiles[@]} --enh ${enhfiles[@]} \
+    --keylist ${logdir}/key.JOB.scp \
+    --outdir ${logdir}/JOB \
+    --evaltypes ${evaltypes} \
+    --permutation ${permutation} \
+    --source-image ${source_image}
+
+
+for t in $(echo ${evaltypes} | sed -e 's/SDR/SDR ISR SIR SAR/'); do
+    for i in $(seq 1 ${nj}); do
+        cat ${logdir}/${i}/${t}
+    done > ${dir}/${t}
+
+    # Calculate the mean over files
+    python << EOF > ${dir}/mean_${t}
+with open('${dir}/${t}', 'r') as f:
+    values = []
+    for l in f:
+        vs = l.rstrip().split(None)[1:]
+        values.append(sum(map(float, vs)) / len(vs))
+    mean = sum(values) / len(values)
+print(mean)
+EOF
+
+done
+