diff --git a/audiolm_pytorch/__init__.py b/audiolm_pytorch/__init__.py
index 512fbd2..f8447c9 100644
--- a/audiolm_pytorch/__init__.py
+++ b/audiolm_pytorch/__init__.py
@@ -5,3 +5,4 @@
 from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper
 
 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
diff --git a/audiolm_pytorch/audiolm_pytorch.py b/audiolm_pytorch/audiolm_pytorch.py
index 335bff6..cc7638f 100644
--- a/audiolm_pytorch/audiolm_pytorch.py
+++ b/audiolm_pytorch/audiolm_pytorch.py
@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn, einsum
@@ -12,6 +12,7 @@
 from vector_quantize_pytorch import ResidualVQ
 
 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
 
 # helper functions
 
@@ -551,7 +552,7 @@ def __init__(
         *,
         num_semantic_tokens,
         dim,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         **kwargs
     ):
         super().__init__()
@@ -574,8 +575,8 @@ def forward(
 
         if not exists(ids):
             assert exists(self.wav2vec)
-            ids = self.wav2vec(raw_wave)
-
+            ids = self.wav2vec(raw_wave, flatten = False)
+
         if return_loss:
             labels, ids = ids.clone(), ids[:, :-1]
 
@@ -606,7 +607,7 @@ def __init__(
         codebook_size,
         num_coarse_quantizers,
         dim,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         **kwargs
     ):
         super().__init__()
@@ -840,7 +841,7 @@ def __init__(
         *,
         transformer: FineTransformer,
         soundstream: Optional[SoundStream] = None,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         num_coarse_quantize = 3
     ):
         super().__init__()
@@ -866,7 +867,7 @@ def forward(
 
         if not exists(semantic_token_ids):
             assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
-            semantic_token_ids = self.wav2vec(raw_wave)
+            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)
 
         if not exists(coarse_token_ids):
             assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
diff --git a/audiolm_pytorch/hubert_kmeans.py b/audiolm_pytorch/hubert_kmeans.py
new file mode 100644
index 0000000..458f98c
--- /dev/null
+++ b/audiolm_pytorch/hubert_kmeans.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+
+import torch
+from torch import nn
+from einops import rearrange, pack, unpack
+
+import joblib
+import fairseq
+
+class HubertWithKmeans(nn.Module):
+    def __init__(
+        self,
+        checkpoint_path,
+        kmeans_path
+    ):
+        super().__init__()
+        model_path = Path(checkpoint_path)
+        kmeans_path = Path(kmeans_path)
+
+        assert model_path.exists(), f'path {checkpoint_path} does not exist'
+        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'
+
+        checkpoint = torch.load(checkpoint_path)
+        load_model_input = {checkpoint_path: checkpoint}
+        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
+
+        self.model = model[0]
+        self.model.eval()
+
+        kmeans = joblib.load(kmeans_path)
+        self.kmeans = kmeans
+
+    @property
+    def groups(self):
+        return 1
+
+    @property
+    def codebook_size(self):
+        return self.kmeans.n_clusters
+
+    @torch.no_grad()
+    def forward(self, wav_input, flatten = True):
+        device = wav_input.device
+
+        embed = self.model(wav_input, features_only = True)
+        embed, packed_shape = pack([embed['x']], '* d')
+
+        codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
+
+        codebook_indices = torch.from_numpy(codebook_indices).to(device).long()
+
+        if flatten:
+            return codebook_indices
+
+        codebook_indices, = unpack(codebook_indices, packed_shape, '*')
+        return codebook_indices
diff --git a/setup.py b/setup.py
index 154e6cd..e684903 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'audiolm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
   author = 'Phil Wang',
@@ -19,9 +19,10 @@
   ],
   install_requires=[
     'accelerate',
-    'einops>=0.5',
+    'einops>=0.6',
     'ema-pytorch',
     'fairseq',
+    'joblib',
     'torch>=1.6',
     'vector-quantize-pytorch>=0.10.5'
   ],
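
A minimal usage sketch of the new HubertWithKmeans wrapper (not part of the commit above; the file paths are hypothetical stand-ins for a downloaded fairseq HuBERT checkpoint and a joblib-serialized k-means model fit on its hidden features):

    import torch
    from audiolm_pytorch import HubertWithKmeans

    # hypothetical paths -- point these at a fairseq HuBERT checkpoint
    # and a joblib-serialized sklearn KMeans trained on HuBERT features
    wav2vec = HubertWithKmeans(
        checkpoint_path = './hubert_base_ls960.pt',
        kmeans_path = './hubert_base_ls960_L9_km500.bin'
    )

    wav = torch.randn(1, 16000)          # one second of 16 kHz audio

    ids = wav2vec(wav, flatten = False)  # (1, n) semantic token ids, one per HuBERT frame
    flat = wav2vec(wav)                  # default flatten = True returns the ids as (n,)

    print(wav2vec.codebook_size)         # the k-means model's n_clusters

The flatten = False path is what the transformer wrappers in this diff rely on: they pass raw waves through wav2vec and need the batch dimension preserved for the ids.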