
Commit b176292

add EMA Vector Quantizer
1 parent 9d17ea6 commit b176292

File tree: 2 files changed (+204, -1 lines)


taming/models/vqgan.py

Lines changed: 42 additions & 1 deletion
@@ -7,7 +7,7 @@
 from taming.modules.diffusionmodules.model import Encoder, Decoder
 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from taming.modules.vqvae.quantize import GumbelQuantize
-
+from taming.modules.vqvae.quantize import EMAVectorQuantizer
 
 class VQModel(pl.LightningModule):
     def __init__(self,
@@ -361,3 +361,44 @@ def log_images(self, batch, **kwargs):
         log["inputs"] = x
         log["reconstructions"] = x_rec
         return log
+
+
+class EMAVQ(VQModel):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 n_embed,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 remap=None,
+                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
+                 ):
+        super().__init__(ddconfig,
+                         lossconfig,
+                         n_embed,
+                         embed_dim,
+                         ckpt_path=None,
+                         ignore_keys=ignore_keys,
+                         image_key=image_key,
+                         colorize_nlabels=colorize_nlabels,
+                         monitor=monitor,
+                         )
+        self.quantize = EMAVectorQuantizer(n_embed=n_embed,
+                                           embedding_dim=embed_dim,
+                                           beta=0.25,
+                                           remap=remap)
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        # Remove self.quantize from parameter list since it is updated via EMA
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
+                                  list(self.decoder.parameters()) +
+                                  list(self.quant_conv.parameters()) +
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
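
Note on configure_optimizers: the EMA codebook is updated inside the quantizer's forward pass rather than by gradient descent, so EMAVQ deliberately leaves self.quantize out of the autoencoder optimizer. A minimal sketch of a sanity check for this, assuming model is an already-constructed EMAVQ with model.learning_rate set (as the training script normally does); the check itself is illustrative and not part of the commit:

# Illustrative check (not part of the commit): the EMA codebook tensors
# must not appear in the autoencoder optimizer's parameter groups.
(opt_ae, opt_disc), _ = model.configure_optimizers()

codebook_tensors = {id(p) for p in model.quantize.parameters()}
optimized_tensors = {id(p)
                     for group in opt_ae.param_groups
                     for p in group["params"]}
assert codebook_tensors.isdisjoint(optimized_tensors)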

taming/modules/vqvae/quantize.py

Lines changed: 162 additions & 0 deletions
@@ -327,3 +327,165 @@ def get_codebook_entry(self, indices, shape):
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
 
         return z_q
+
+
+
+class EMAVectorQuantizer(nn.Module):
+    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                 remap=None, unknown_index="random"):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.decay = decay
+        self.eps = eps
+        self.beta = beta
+        self.embedding = nn.Embedding(self.n_embed, self.embedding_dim)
+        self.embedding.weight.requires_grad = False
+        self.cluster_size = nn.Parameter(torch.zeros(n_embed), requires_grad=False)
+        self.embed_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embedding_dim), requires_grad=False)
+        self.embed_avg.data.copy_(self.embedding.weight.data)
+        self.remap = remap
+        if self.remap is not None:
+            self.register_buffer("used", torch.tensor(np.load(self.remap)))
+            self.re_embed = self.used.shape[0]
+            self.unknown_index = unknown_index  # "random" or "extra" or integer
+            if self.unknown_index == "extra":
+                self.unknown_index = self.re_embed
+                self.re_embed = self.re_embed + 1
+            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+                  f"Using {self.unknown_index} for unknown indices.")
+        else:
+            self.re_embed = n_embed
+
+    def remap_to_used(self, inds):
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        match = (inds[:, :, None] == used[None, None, ...]).long()
+        new = match.argmax(-1)
+        unknown = match.sum(2) < 1
+        if self.unknown_index == "random":
+            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+        else:
+            new[unknown] = self.unknown_index
+        return new.reshape(ishape)
+
+    def unmap_to_all(self, inds):
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        if self.re_embed > self.used.shape[0]:  # extra token
+            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+        back = torch.gather(used[None, :][inds.shape[0]*[0], :], 1, inds)
+        return back.reshape(ishape)
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        # z, 'b c h w -> b h w c'
+        z = z.permute(0, 2, 3, 1).contiguous()
+        z_flattened = z.view(-1, self.embedding_dim)
+        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
+
+        d = torch.sum(z_flattened.pow(2), dim=1, keepdim=True) + \
+            torch.sum(self.embedding.weight.pow(2), dim=1) - 2 * \
+            torch.einsum('bd,dn->bn', z_flattened, self.embedding.weight.permute(1, 0))  # 'n d -> d n'
+
+        encoding_indices = torch.argmin(d, dim=1)
+        z_q = self.embedding(encoding_indices).view(z.shape)
+        encodings = F.one_hot(encoding_indices, self.n_embed).type(z.dtype)
+        avg_probs = torch.mean(encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+        if self.training:
+            encodings_sum = encodings.sum(0)
+            # EMA cluster size
+            self.cluster_size.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
+
+            embed_sum = torch.matmul(encodings.t(), z_flattened)
+            # EMA embedding average
+            self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
+
+            # cluster size Laplace smoothing
+            n = self.cluster_size.sum()
+            cluster_size = (
+                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
+            )
+            # normalize embedding average with smoothed cluster size
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embedding.weight.data.copy_(embed_normalized.data)
+
+        # compute loss for embedding
+        loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        # reshape back to match original input shape
+        # z_q, 'b h w c -> b c h w'
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+        return z_q, loss, (perplexity, encodings, encoding_indices)
+
+
+
+# Original Sonnet version of EMAVectorQuantizer
+class EmbeddingEMA(nn.Module):
+    def __init__(self, n_embed, embedding_dim):
+        super().__init__()
+        weight = torch.randn(embedding_dim, n_embed)
+        self.register_buffer("weight", weight)
+        self.register_buffer("cluster_size", torch.zeros(n_embed))
+        self.register_buffer("embed_avg", weight.clone())
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight.transpose(0, 1))
+
+
+class SonnetEMAVectorQuantizer(nn.Module):
+    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                 remap=None, unknown_index="random"):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.decay = decay
+        self.eps = eps
+        self.beta = beta
+        self.embedding = EmbeddingEMA(n_embed, embedding_dim)
+
+    def forward(self, z):
+        z = z.permute(0, 2, 3, 1).contiguous()
+        z_flattened = z.reshape(-1, self.embedding_dim)
+        d = (
+            z_flattened.pow(2).sum(1, keepdim=True)
+            - 2 * z_flattened @ self.embedding.weight
+            + self.embedding.weight.pow(2).sum(0, keepdim=True)
+        )
+        _, encoding_indices = (-d).max(1)
+        encodings = F.one_hot(encoding_indices, self.n_embed).type(z_flattened.dtype)
+        encoding_indices = encoding_indices.view(*z.shape[:-1])
+        z_q = self.embedding(encoding_indices)
+        avg_probs = torch.mean(encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+        if self.training:
+            encodings_sum = encodings.sum(0)
+            embed_sum = z_flattened.transpose(0, 1) @ encodings
+            # EMA cluster size
+            self.embedding.cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
+            # EMA embedding average
+            self.embedding.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
+
+            # cluster size Laplace smoothing
+            n = self.embedding.cluster_size.sum()
+            cluster_size = (
+                (self.embedding.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
+            )
+            # normalize embedding average with smoothed cluster size
+            embed_normalized = self.embedding.embed_avg / cluster_size.unsqueeze(0)
+            self.embedding.weight.data.copy_(embed_normalized)
+
+        loss = self.beta * (z_q.detach() - z).pow(2).mean()
+        z_q = z + (z_q - z).detach()
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+        return z_q, loss, (perplexity, encodings, encoding_indices)
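
Note on the EMA update: EMAVectorQuantizer freezes the codebook weights (self.embedding.weight.requires_grad = False) and, while self.training is True, updates cluster_size and embed_avg as exponential moving averages of the per-code assignment counts and of the sum of encoder features assigned to each code, Laplace-smooths the counts, and copies embed_avg / cluster_size back into the embedding. Only the commitment term beta * F.mse_loss(z_q.detach(), z), combined with the straight-through estimator z + (z_q - z).detach(), reaches the encoder. A minimal smoke test, with illustrative shapes and hyperparameters that are not taken from any config in the repo:

# Illustrative smoke test (not part of the commit).
import torch

from taming.modules.vqvae.quantize import EMAVectorQuantizer

quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)
quantizer.train()  # EMA codebook updates only run when self.training is True

z = torch.randn(2, 256, 16, 16)  # (batch, channels, height, width)
z_q, loss, (perplexity, encodings, indices) = quantizer(z)

assert z_q.shape == z.shape              # quantized map keeps the input layout
assert indices.shape == (2 * 16 * 16,)   # one codebook index per spatial position
assert loss.dim() == 0                   # scalar commitment loss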
