residual lfq #80

Merged · 12 commits · Oct 21, 2023
README.md: 29 additions & 0 deletions
@@ -358,6 +358,35 @@ assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
```

An improvised Residual LFQ, to see if it can lead to an improvement for audio compression.

```python
import torch
from vector_quantize_pytorch import ResidualLFQ

residual_lfq = ResidualLFQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)

residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)

# (1, 1024, 256), (1, 1024, 8), (8)
# (batch, seq, dim), (batch, seq, quantizers), (quantizers)

quantized_out = residual_lfq.get_codes_from_indices(indices)

# (8, 1, 1024, 8)
# (residual layers, batch, seq, codebook dim) - codebook dim = log2(codebook_size) = 8 here

assert torch.all(quantized == residual_lfq.project_out(quantized_out.sum(dim = 0)))
```
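
Because the quantization is residual, a truncated set of indices (the coarser quantizers only) can still be decoded: `get_codes_from_indices` pads the missing quantizer positions with -1 and masks their codes to zero. Below is a minimal sketch of that, assuming the module is constructed with `quantize_dropout = True` (required by the assertion guarding coarse reconstruction); the cutoff of 4 quantizers is arbitrary.

```python
import torch
from vector_quantize_pytorch import ResidualLFQ

residual_lfq = ResidualLFQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 8,
    quantize_dropout = True   # needed to decode from fewer than num_quantizers indices
)

x = torch.randn(1, 1024, 256)

residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)

coarse_indices = indices[..., :4]   # keep only the 4 coarsest quantizers

coarse_codes = residual_lfq.get_codes_from_indices(coarse_indices)

# (8, 1, 1024, 8) - the 4 dropped (finer) layers contribute all-zero codes

coarse_recon = residual_lfq.project_out(coarse_codes.sum(dim = 0))

# (1, 1024, 256) - a coarser reconstruction than summing all 8 quantizer outputs
```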

## Citations

setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.9.18',
+  version = '1.10.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
vector_quantize_pytorch/__init__.py: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
 from vector_quantize_pytorch.finite_scalar_quantization import FSQ
 from vector_quantize_pytorch.lookup_free_quantization import LFQ
+from vector_quantize_pytorch.residual_lfq import ResidualLFQ
vector_quantize_pytorch/lookup_free_quantization.py: 14 additions & 5 deletions
@@ -68,7 +68,8 @@ def __init__(
         diversity_gamma = 2.5,
         straight_through_activation = nn.Identity(),
         num_codebooks = 1,
-        keep_num_codebooks_dim = None
+        keep_num_codebooks_dim = None,
+        codebook_scale = 1. # for residual LFQ, codebook scaled down by 2x at each layer
     ):
         super().__init__()
@@ -103,6 +104,10 @@ def __init__(
         self.diversity_gamma = diversity_gamma
         self.entropy_loss_weight = entropy_loss_weight
 
+        # codebook scale
+
+        self.codebook_scale = codebook_scale
+
         # commitment loss
 
         self.commitment_loss_weight = commitment_loss_weight
@@ -116,10 +121,13 @@
 
         all_codes = torch.arange(codebook_size)
         bits = ((all_codes[..., None].int() & self.mask) != 0).float()
-        codebook = bits * 2 - 1
+        codebook = self.bits_to_codes(bits)
 
         self.register_buffer('codebook', codebook, persistent = False)
 
+    def bits_to_codes(self, bits):
+        return bits * self.codebook_scale * 2 - self.codebook_scale
+
     @property
     def dtype(self):
         return self.codebook.dtype
@@ -137,7 +145,8 @@ def indices_to_codes(
         # indices to codes, which are bits of either -1 or 1
 
         bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
-        codes = bits * 2 - 1
+
+        codes = self.bits_to_codes(bits)
 
         codes = rearrange(codes, '... c d -> ... (c d)')
@@ -188,8 +197,8 @@ def forward(
 
         original_input = x
 
-        ones = torch.ones_like(x)
-        quantized = torch.where(x > 0, ones, -ones)
+        codebook_value = torch.ones_like(x) * self.codebook_scale
+        quantized = torch.where(x > 0, codebook_value, -codebook_value)
 
         # use straight-through gradients with tanh (or custom activation fn) if training
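
The new `codebook_scale` is just an affine remap of the binary code: bits in {0, 1} become codes in {-codebook_scale, +codebook_scale}, and the residual wrapper below halves the scale at each successive layer (`codebook_scale = 2 ** -ind`). A standalone sketch of that mapping, for illustration only:

```python
import torch

# mirrors LFQ.bits_to_codes above: {0, 1} bits -> {-scale, +scale} codes
def bits_to_codes(bits, codebook_scale):
    return bits * codebook_scale * 2 - codebook_scale

bits = torch.tensor([0., 1., 1., 0.])

print(bits_to_codes(bits, 1.))    # codes in {-1, 1}       (first residual layer, scale 2 ** 0)
print(bits_to_codes(bits, 0.5))   # codes in {-0.5, 0.5}   (second layer, scale 2 ** -1)
print(bits_to_codes(bits, 0.25))  # codes in {-0.25, 0.25} (third layer, scale 2 ** -2)
```
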
vector_quantize_pytorch/residual_lfq.py: 180 additions & 0 deletions (new file)
@@ -0,0 +1,180 @@
from math import log2, ceil
from random import randrange
from functools import partial

import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast

from vector_quantize_pytorch.lookup_free_quantization import LFQ

from einops import rearrange, repeat, pack, unpack

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

# main class

class ResidualLFQ(Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_size,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        **kwargs
    ):
        super().__init__()
        codebook_dim = int(log2(codebook_size))

        requires_projection = codebook_dim != dim
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()

        self.num_quantizers = num_quantizers

        self.layers = nn.ModuleList([])

        for ind in range(num_quantizers):
            codebook_scale = 2 ** -ind

            lfq = LFQ(
                dim = codebook_dim,
                codebook_scale = codebook_scale,
                **kwargs
            )

            self.layers.append(lfq)

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebooks(self):
        codebooks = [layer.codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        return codebooks

    def get_codes_from_indices(self, indices):

        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], 'b * q')

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # get ready for gathering

        codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
        gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])

        # take care of quantizer dropout

        mask = gather_indices == -1.
        gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = codebooks.gather(2, gather_indices) # gather all codes

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(mask, 0.)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes

    def forward(
        self,
        x,
        return_all_codes = False
    ):
        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

        x = self.project_in(x)

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            null_indices = torch.full(x.shape[:2], -1., device = device, dtype = torch.long)
            null_loss = torch.tensor(0., device = device, dtype = x.dtype)

        # go through the layers

        with autocast(enabled = False):
            for quantizer_index, layer in enumerate(self.layers):

                if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                    all_indices.append(null_indices)
                    all_losses.append(null_loss)
                    continue

                quantized, indices, loss = layer(residual)

                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)
                all_losses.append(loss)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # stack all losses and indices

        all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
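
As a mental model for the forward loop above: each layer binarizes the current residual at its own (halved) scale, the quantized value is detached and subtracted from the residual, and the per-layer outputs are summed. A toy sketch of that recurrence on a plain tensor, ignoring the projections, straight-through gradients, and entropy/commitment losses; the example values are chosen to lie in [-2, 2], where the halving scales can resolve them:

```python
import torch

x = torch.tensor([0.3, -1.2, 0.75, 1.9])

residual = x
quantized_out = torch.zeros_like(x)

for ind in range(8):
    scale = 2 ** -ind  # same schedule as ResidualLFQ: 1, 0.5, 0.25, ...

    # per-dimension sign of the residual at this layer's scale
    quantized = torch.where(residual > 0, torch.ones_like(residual), -torch.ones_like(residual)) * scale

    residual = residual - quantized
    quantized_out = quantized_out + quantized

# the summed codes approximate the input to within the finest scale (2 ** -7 here)
print(x)
print(quantized_out)
print((x - quantized_out).abs().max())
```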