@@ -327,3 +327,119 @@ def get_codebook_entry(self, indices, shape):
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
 
         return z_q
+
+class EmbeddingEMA(nn.Module):
+    # Codebook table whose weights are updated with an exponential moving average (EMA)
+    # of the encoder outputs assigned to each code, rather than by gradient descent.
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+        super().__init__()
+        self.decay = decay
+        self.eps = eps
+        weight = torch.randn(num_tokens, codebook_dim)
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+        self.update = True
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_embed_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        # Laplace-smooth the cluster sizes so no code has exactly zero mass
+        smoothed_cluster_size = (
+            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+        )
+        # normalize embedding average with smoothed cluster size
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(nn.Module):
+    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                 remap=None, unknown_index="random"):
+        super().__init__()
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
+        self.beta = beta
+        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+        self.remap = remap
+        if self.remap is not None:
+            self.register_buffer("used", torch.tensor(np.load(self.remap)))
+            self.re_embed = self.used.shape[0]
+            self.unknown_index = unknown_index  # "random" or "extra" or integer
+            if self.unknown_index == "extra":
+                self.unknown_index = self.re_embed
+                self.re_embed = self.re_embed + 1
+            print(f"Remapping {n_embed} indices to {self.re_embed} indices. "
+                  f"Using {self.unknown_index} for unknown indices.")
+        else:
+            self.re_embed = n_embed
+
+    def remap_to_used(self, inds):
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        match = (inds[:, :, None] == used[None, None, ...]).long()
+        new = match.argmax(-1)
+        unknown = match.sum(2) < 1
+        if self.unknown_index == "random":
+            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+        else:
+            new[unknown] = self.unknown_index
+        return new.reshape(ishape)
+
+    def unmap_to_all(self, inds):
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        if self.re_embed > self.used.shape[0]:  # extra token
+            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+        return back.reshape(ishape)
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        # z: 'b c h w -> b h w c'
+        z = rearrange(z, 'b c h w -> b h w c')
+        z_flattened = z.reshape(-1, self.codebook_dim)
+
+        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'
+
+        encoding_indices = torch.argmin(d, dim=1)
+
+        z_q = self.embedding(encoding_indices).view(z.shape)
+        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+        avg_probs = torch.mean(encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+        if self.training and self.embedding.update:
+            # EMA cluster size
+            encodings_sum = encodings.sum(0)
+            self.embedding.cluster_size_ema_update(encodings_sum)
+            # EMA embedding average
+            embed_sum = encodings.transpose(0, 1) @ z_flattened
+            self.embedding.embed_avg_ema_update(embed_sum)
+            # normalize embed_avg and update weight
+            self.embedding.weight_update(self.num_tokens)
+
+        # commitment loss only: the codebook itself is updated via the EMA statistics above
+        loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+        # preserve gradients (straight-through estimator)
+        z_q = z + (z_q - z).detach()
+
+        # reshape back to match original input shape
+        # z_q: 'b h w c -> b c h w'
+        z_q = rearrange(z_q, 'b h w c -> b c h w')
+        return z_q, loss, (perplexity, encodings, encoding_indices)
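
Note: the EMA codebook is maintained entirely through the buffers above. cluster_size keeps a decayed count of how often each code is selected, embed_avg keeps a decayed sum of the encoder vectors assigned to it, and weight_update divides the two after Laplace-smoothing the counts so that a rarely used code never causes a division by zero. A minimal standalone sketch of that smoothing step (the counts below are made up for illustration):

import torch

eps = 1e-5
num_tokens = 4
cluster_size = torch.tensor([100.0, 10.0, 0.0, 1.0])  # EMA usage counts; the third code is unused

n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
print(smoothed)  # close to the raw counts, but the unused code gets a small positive value,
                 # so dividing embed_avg by smoothed.unsqueeze(1) stays well defined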
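For orientation, here is a minimal usage sketch of the quantizer added in this commit. The hyperparameters and shapes are illustrative rather than taken from a repo config, and the import path is an assumption about where the class lives:

import torch
from taming.modules.vqvae.quantize import EMAVectorQuantizer  # assumed import path

quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)
quantizer.train()  # EMA statistics are only updated in training mode

z = torch.randn(2, 256, 16, 16)  # encoder output: (batch, channels, height, width)
z_q, loss, (perplexity, encodings, indices) = quantizer(z)

print(z_q.shape)          # torch.Size([2, 256, 16, 16]), same shape as z
print(loss.item())        # beta-weighted commitment loss
print(perplexity.item())  # effective number of codebook entries in use

Note that the returned encoding indices are flattened to shape (batch * height * width,); callers that need the spatial code map have to reshape them back themselves.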