Skip to content

Commit

Permalink
Merge pull request #12 from gudgud96/feat/float32-wav
Browse files Browse the repository at this point in the history
Support float32 PCM wav files
  • Loading branch information
gudgud96 authored Jul 17, 2023
2 parents f92b37d + 07833f1 commit a116459
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022 Hao Hao Tan
Copyright (c) 2023 Hao Hao Tan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ frechet = FrechetAudioDistance(
model_name="vggish",
use_pca=False,
use_activation=False,
verbose=False
verbose=False,
dtype="float32"
)
# to use `PANN`
frechet = FrechetAudioDistance(
model_name="pann",
use_pca=False,
use_activation=False,
verbose=False
verbose=False,
dtype="float32"
)
fad_score = frechet.score("/path/to/background/set", "/path/to/eval/set")

Expand All @@ -50,6 +52,11 @@ FAD scores comparison w.r.t. to original implementation in `google-research/frec
|:----------------------------:|:---------------------:|:------------------------:|
| `frechet_audio_distance` | 0.000465 | 0.00008594 |

### To contribute

- Run `python3 -m build` to build your version locally. The built wheel should be in `dist/`.
- Install the locally built wheel with `pip install`, then run `pytest test/` to validate your changes.

### References

VGGish in PyTorch: https://github.com/harritaylor/torchvggish
Expand Down
45 changes: 34 additions & 11 deletions frechet_audio_distance/fad.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@
SAMPLE_RATE = 16000


def load_audio_task(fname):
wav_data, sr = sf.read(fname, dtype='int16')
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
wav_data = wav_data / 32768.0 # Convert to [-1.0, +1.0]

def load_audio_task(fname, dtype="float32"):
if dtype not in ['float64', 'float32', 'int32', 'int16']:
raise ValueError(f"dtype not supported: {dtype}")

wav_data, sr = sf.read(fname, dtype=dtype)
# For integer type PCM input, convert to [-1.0, +1.0]
if dtype == 'int16':
wav_data = wav_data / 32768.0
elif dtype == 'int32':
wav_data = wav_data / float(2**31)

# Convert to mono
if len(wav_data.shape) > 1:
wav_data = np.mean(wav_data, axis=1)
Expand All @@ -36,7 +42,14 @@ def load_audio_task(fname):


class FrechetAudioDistance:
def __init__(self, model_name="vggish", use_pca=False, use_activation=False, verbose=False, audio_load_worker=8):
def __init__(
self,
model_name="vggish",
use_pca=False,
use_activation=False,
verbose=False,
audio_load_worker=8
):
self.model_name = model_name
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.__get_model(model_name=model_name, use_pca=use_pca, use_activation=use_activation)
Expand Down Expand Up @@ -157,7 +170,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
return (diff.dot(diff) + np.trace(sigma1)
+ np.trace(sigma2) - 2 * tr_covmean)

def __load_audio_files(self, dir):
def __load_audio_files(self, dir, dtype="float32"):
task_results = []

pool = ThreadPool(self.audio_load_worker)
Expand All @@ -169,17 +182,27 @@ def update(*a):
if self.verbose:
print("[Frechet Audio Distance] Loading audio from {}...".format(dir))
for fname in os.listdir(dir):
res = pool.apply_async(load_audio_task, args=(os.path.join(dir, fname),), callback=update)
res = pool.apply_async(
load_audio_task,
args=(os.path.join(dir, fname), dtype,),
callback=update
)
task_results.append(res)
pool.close()
pool.join()

return [k.get() for k in task_results]

def score(self, background_dir, eval_dir, store_embds=False):
def score(
self,
background_dir,
eval_dir,
store_embds=False,
dtype="float32"
):
try:
audio_background = self.__load_audio_files(background_dir)
audio_eval = self.__load_audio_files(eval_dir)
audio_background = self.__load_audio_files(background_dir, dtype=dtype)
audio_eval = self.__load_audio_files(eval_dir, dtype=dtype)

embds_background = self.get_embeddings(audio_background)
embds_eval = self.get_embeddings(audio_eval)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "frechet_audio_distance"
version = "0.1.0"
version = "0.1.1"
authors = [
{ name="Hao Hao Tan", email="helloharry66@gmail.com" },
]
Expand All @@ -24,6 +24,7 @@ dependencies = [
'tqdm',
'soundfile',
'resampy',
'torchlibrosa'
]

[project.urls]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ torch
scipy
tqdm
soundfile
resampy
resampy
torchlibrosa

0 comments on commit a116459

Please sign in to comment.