Skip to content

Commit

Permalink
integrate residual lookup free quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 22, 2023
1 parent 4bc50b0 commit c748fcd
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 21 deletions.
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ from audiolm_pytorch import SoundStream, SoundStreamTrainer
soundstream = SoundStream(
codebook_size = 1024,
rq_num_quantizers = 8,
rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
attn_window_size = 128, # local attention receptive field at bottleneck
attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
use_lookup_free_quantizer = True, # whether to use residual lookup free quantization - note: currently requires rq_groups = 1, as grouped residual LFQ is not implemented yet
attn_window_size = 128, # local attention receptive field at bottleneck
attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

trainer = SoundStreamTrainer(
Expand Down Expand Up @@ -509,3 +510,14 @@ $ accelerate launch train.py
year = {2022}
}
```

```bibtex
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
48 changes: 32 additions & 16 deletions audiolm_pytorch/soundstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from einops import rearrange, reduce, pack, unpack

from vector_quantize_pytorch import GroupedResidualVQ
from vector_quantize_pytorch import (
GroupedResidualVQ,
ResidualLFQ
)

from local_attention import LocalMHA
from local_attention.transformer import FeedForward, DynamicPositionBias
Expand Down Expand Up @@ -433,6 +436,7 @@ def __init__(
rq_groups = 1,
rq_stochastic_sample_codes = False,
rq_kwargs: dict = {},
use_lookup_free_quantizer = True, # proposed in https://arxiv.org/abs/2310.05737, adapted in residual quantization fashion for audio
input_channels = 1,
discr_multi_scales = (1, 0.5, 0.25),
stft_normalized = False,
Expand Down Expand Up @@ -513,21 +517,33 @@ def __init__(

self.rq_groups = rq_groups

self.rq = GroupedResidualVQ(
dim = codebook_dim,
num_quantizers = rq_num_quantizers,
codebook_size = codebook_size,
groups = rq_groups,
decay = rq_ema_decay,
commitment_weight = rq_commitment_weight,
quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of,
kmeans_init = True,
threshold_ema_dead_code = 2,
quantize_dropout = True,
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
stochastic_sample_codes = rq_stochastic_sample_codes,
**rq_kwargs
)
if use_lookup_free_quantizer:
assert rq_groups == 1, 'grouped residual LFQ not implemented yet'

self.rq = ResidualLFQ(
dim = codebook_dim,
num_quantizers = rq_num_quantizers,
codebook_size = codebook_size,
quantize_dropout = True,
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
**rq_kwargs
)
else:
self.rq = GroupedResidualVQ(
dim = codebook_dim,
num_quantizers = rq_num_quantizers,
codebook_size = codebook_size,
groups = rq_groups,
decay = rq_ema_decay,
commitment_weight = rq_commitment_weight,
quantize_dropout_multiple_of = rq_quantize_dropout_multiple_of,
kmeans_init = True,
threshold_ema_dead_code = 2,
quantize_dropout = True,
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index,
stochastic_sample_codes = rq_stochastic_sample_codes,
**rq_kwargs
)

self.decoder_film = FiLM(codebook_dim, dim_cond = 2)

Expand Down
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.7'
__version__ = '1.6.0'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
'torchaudio',
'transformers',
'tqdm',
'vector-quantize-pytorch>=1.7.0'
'vector-quantize-pytorch>=1.10.2'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit c748fcd

Please sign in to comment.