listen to @eonglints and add hubert with kmeans as an option
lucidrains committed Nov 10, 2022
1 parent ed313d3 commit a11722e
Showing 4 changed files with 68 additions and 9 deletions.
1 change: 1 addition & 0 deletions audiolm_pytorch/__init__.py
@@ -5,3 +5,4 @@
 from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper
 
 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
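
With this export in place, the new tokenizer can be imported from the package root just like FairseqVQWav2Vec. A one-line check (not part of the commit):

from audiolm_pytorch import HubertWithKmeans, FairseqVQWav2Vec
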
15 changes: 8 additions & 7 deletions audiolm_pytorch/audiolm_pytorch.py
@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch import nn, einsum
@@ -12,6 +12,7 @@
 from vector_quantize_pytorch import ResidualVQ
 
 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
 
 # helper functions
 
@@ -551,7 +552,7 @@ def __init__(
         *,
         num_semantic_tokens,
         dim,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         **kwargs
     ):
         super().__init__()
@@ -574,8 +575,8 @@ def forward(
 
         if not exists(ids):
             assert exists(self.wav2vec)
-            ids = self.wav2vec(raw_wave)
+            ids = self.wav2vec(raw_wave, flatten = False)
 
         if return_loss:
             labels, ids = ids.clone(), ids[:, :-1]
 
@@ -606,7 +607,7 @@ def __init__(
         codebook_size,
         num_coarse_quantizers,
         dim,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         **kwargs
     ):
         super().__init__()
@@ -840,7 +841,7 @@ def __init__(
         *,
         transformer: FineTransformer,
         soundstream: Optional[SoundStream] = None,
-        wav2vec: Optional[FairseqVQWav2Vec] = None,
+        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
         num_coarse_quantize = 3
     ):
         super().__init__()
@@ -866,7 +867,7 @@ def forward(
 
         if not exists(semantic_token_ids):
             assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
-            semantic_token_ids = self.wav2vec(raw_wave)
+            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)
 
         if not exists(coarse_token_ids):
             assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
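
Besides widening the wav2vec type hints to accept either FairseqVQWav2Vec or HubertWithKmeans, the hunks above switch every call to the tokenizer over to flatten = False, so the returned semantic ids keep their (batch, seq) shape. That shape is what lets the transformer build next-token targets with labels, ids = ids.clone(), ids[:, :-1]. A toy illustration of that shift (not from the repo):

import torch

ids = torch.tensor([
    [5, 12,  7, 3],
    [9,  1, 14, 2]
])                          # (batch = 2, seq = 4) semantic token ids

labels = ids.clone()        # the full sequence is the prediction target
inputs = ids[:, :-1]        # the model is fed everything except the last token

print(inputs.shape, labels.shape)   # torch.Size([2, 3]) torch.Size([2, 4])
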
56 changes: 56 additions & 0 deletions audiolm_pytorch/hubert_kmeans.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+
+import torch
+from torch import nn
+from einops import rearrange, pack, unpack
+
+import joblib
+import fairseq
+
+class HubertWithKmeans(nn.Module):
+    def __init__(
+        self,
+        checkpoint_path,
+        kmeans_path
+    ):
+        super().__init__()
+        model_path = Path(checkpoint_path)
+        kmeans_path = Path(kmeans_path)
+
+        assert model_path.exists(), f'path {checkpoint_path} does not exist'
+        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'
+
+        checkpoint = torch.load(checkpoint_path)
+        load_model_input = {checkpoint_path: checkpoint}
+        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
+
+        self.model = model[0]
+        self.model.eval()
+
+        kmeans = joblib.load(kmeans_path)
+        self.kmeans = kmeans
+
+    @property
+    def groups(self):
+        return 1
+
+    @property
+    def codebook_size(self):
+        return self.kmeans.n_clusters
+
+    @torch.no_grad()
+    def forward(self, wav_input, flatten = True):
+        device = wav_input.device
+
+        embed = self.model(wav_input, features_only = True)
+        embed, packed_shape = pack([embed['x']], '* d')
+
+        codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
+
+        codebook_indices = torch.from_numpy(codebook_indices).to(device).long()
+
+        if flatten:
+            return codebook_indices
+
+        codebook_indices, = unpack(codebook_indices, packed_shape, '*')
+        return codebook_indices
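
A minimal usage sketch for the new class (not part of the commit). The checkpoint and k-means file names below are assumptions; any fairseq HuBERT checkpoint plus a joblib-saved scikit-learn k-means model should satisfy the loading code above:

import torch
from audiolm_pytorch import HubertWithKmeans

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',        # assumed path to a fairseq HuBERT checkpoint
    kmeans_path = './hubert_base_ls960_L9_km500.bin'   # assumed path to a joblib-saved k-means model
)

wave = torch.randn(1, 16000)              # one second of 16 kHz audio

flat_ids = wav2vec(wave)                  # flatten = True (default): 1-D tensor of cluster ids
ids = wav2vec(wave, flatten = False)      # (batch, seq), the shape the transformer wrappers now request

print(wav2vec.codebook_size, ids.shape)
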
5 changes: 3 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'audiolm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
   author = 'Phil Wang',
@@ -19,9 +19,10 @@
   ],
   install_requires=[
     'accelerate',
-    'einops>=0.5',
+    'einops>=0.6',
     'ema-pytorch',
     'fairseq',
+    'joblib',
     'torch>=1.6',
     'vector-quantize-pytorch>=0.10.5'
   ],

1 comment on commit a11722e

@eonglints
Contributor

❤️
