From 02f468ea91939dc9679d4697cda185e1aa8935ed Mon Sep 17 00:00:00 2001 From: gudgud96 Date: Mon, 17 Jul 2023 11:51:05 +0800 Subject: [PATCH 1/2] [feat] support float32 PCM wav files --- frechet_audio_distance/fad.py | 45 ++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/frechet_audio_distance/fad.py b/frechet_audio_distance/fad.py index 8f51abf..9914e99 100644 --- a/frechet_audio_distance/fad.py +++ b/frechet_audio_distance/fad.py @@ -20,11 +20,17 @@ SAMPLE_RATE = 16000 -def load_audio_task(fname): - wav_data, sr = sf.read(fname, dtype='int16') - assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype - wav_data = wav_data / 32768.0 # Convert to [-1.0, +1.0] - +def load_audio_task(fname, dtype="float32"): + if dtype not in ['float64', 'float32', 'int32', 'int16']: + raise ValueError(f"dtype not supported: {dtype}") + + wav_data, sr = sf.read(fname, dtype=dtype) + # For integer type PCM input, convert to [-1.0, +1.0] + if dtype == 'int16': + wav_data = wav_data / 32768.0 + elif dtype == 'int32': + wav_data = wav_data / float(2**31) + # Convert to mono if len(wav_data.shape) > 1: wav_data = np.mean(wav_data, axis=1) @@ -36,7 +42,14 @@ def load_audio_task(fname): class FrechetAudioDistance: - def __init__(self, model_name="vggish", use_pca=False, use_activation=False, verbose=False, audio_load_worker=8): + def __init__( + self, + model_name="vggish", + use_pca=False, + use_activation=False, + verbose=False, + audio_load_worker=8 + ): self.model_name = model_name self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') self.__get_model(model_name=model_name, use_pca=use_pca, use_activation=use_activation) @@ -157,7 +170,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) - def __load_audio_files(self, dir): + def __load_audio_files(self, dir, dtype="float32"): task_results = [] pool = ThreadPool(self.audio_load_worker) @@ -169,17 +182,27 @@ def update(*a): if self.verbose: print("[Frechet Audio Distance] Loading audio from {}...".format(dir)) for fname in os.listdir(dir): - res = pool.apply_async(load_audio_task, args=(os.path.join(dir, fname),), callback=update) + res = pool.apply_async( + load_audio_task, + args=(os.path.join(dir, fname), dtype,), + callback=update + ) task_results.append(res) pool.close() pool.join() return [k.get() for k in task_results] - def score(self, background_dir, eval_dir, store_embds=False): + def score( + self, + background_dir, + eval_dir, + store_embds=False, + dtype="float32" + ): try: - audio_background = self.__load_audio_files(background_dir) - audio_eval = self.__load_audio_files(eval_dir) + audio_background = self.__load_audio_files(background_dir, dtype=dtype) + audio_eval = self.__load_audio_files(eval_dir, dtype=dtype) embds_background = self.get_embeddings(audio_background) embds_eval = self.get_embeddings(audio_eval) From 07833f12284143b274ac3e9781bdf0adb9c2fa45 Mon Sep 17 00:00:00 2001 From: gudgud96 Date: Mon, 17 Jul 2023 11:51:27 +0800 Subject: [PATCH 2/2] [chore] bump to 0.1.1 --- LICENSE | 2 +- README.md | 11 +++++++++-- pyproject.toml | 3 ++- requirements.txt | 3 ++- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/LICENSE b/LICENSE index aee6d1d..fda2785 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 Hao Hao Tan +Copyright (c) 2023 Hao Hao Tan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 495db25..4331cff 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,16 @@ frechet = FrechetAudioDistance( model_name="vggish", use_pca=False, use_activation=False, - verbose=False + verbose=False, + dtype="float32" ) # to use `PANN` frechet = FrechetAudioDistance( model_name="pann", use_pca=False, use_activation=False, - verbose=False + verbose=False, + dtype="float32" ) fad_score = frechet.score("/path/to/background/set", "/path/to/eval/set") @@ -50,6 +52,11 @@ FAD scores comparison w.r.t. to original implementation in `google-research/frec |:----------------------------:|:---------------------:|:------------------------:| | `frechet_audio_distance` | 0.000465 | 0.00008594 | +### To contribute + +- Run `python3 -m build` to build your version locally. The built wheel should be in `dist/`. +- `pip install` your local wheel version, and run `pytest test/` to validate your changes. + ### References VGGish in PyTorch: https://github.com/harritaylor/torchvggish diff --git a/pyproject.toml b/pyproject.toml index 5bcf0d7..b39a996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "frechet_audio_distance" -version = "0.1.0" +version = "0.1.1" authors = [ { name="Hao Hao Tan", email="helloharry66@gmail.com" }, ] @@ -24,6 +24,7 @@ dependencies = [ 'tqdm', 'soundfile', 'resampy', + 'torchlibrosa' ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 11631bf..e0474eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ torch scipy tqdm soundfile -resampy \ No newline at end of file +resampy +torchlibrosa \ No newline at end of file