@@ -328,21 +328,44 @@ def get_codebook_entry(self, indices, shape):
 
         return z_q
 
+class EmbeddingEMA(nn.Module):
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+        super().__init__()
+        self.decay = decay
+        self.eps = eps
+        weight = torch.randn(num_tokens, codebook_dim)
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+        self.update = True
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_embed_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        smoothed_cluster_size = (
+            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+        )
+        #normalize embedding average with smoothed cluster size
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        self.weight.data.copy_(embed_normalized)
 
 
 class EMAVectorQuantizer(nn.Module):
     def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.n_embed = n_embed
-        self.decay = decay
-        self.eps = eps
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
         self.beta = beta
-        self.embedding = nn.Embedding(self.n_embed, self.embedding_dim)
-        self.embedding.weight.requires_grad = False
-        self.cluster_size = nn.Parameter(torch.zeros(n_embed), requires_grad=False)
-        self.embed_avg = nn.Parameter(torch.randn(self.n_embed, self.embedding_dim), requires_grad=False)
+        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
 
         self.remap = remap
         if self.remap is not None:
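For reference, the update that EmbeddingEMA now encapsulates is the standard exponential-moving-average codebook update: per-code cluster counts and embedding sums are decayed each step, the counts are Laplace-smoothed, and the codebook weight becomes the smoothed average. Below is a minimal standalone sketch of that arithmetic; the tensor names and toy sizes are illustrative and not part of this commit.

import torch
import torch.nn.functional as F

decay, eps = 0.99, 1e-5
num_tokens, codebook_dim = 4, 2

# toy state mirroring EmbeddingEMA's buffers (illustrative values)
cluster_size = torch.zeros(num_tokens)
embed_avg = torch.randn(num_tokens, codebook_dim)
weight = embed_avg.clone()

# one batch of flattened encoder outputs and their one-hot code assignments
flat = torch.randn(8, codebook_dim)
encodings = F.one_hot(torch.randint(0, num_tokens, (8,)), num_tokens).float()

# EMA updates, as in cluster_size_ema_update / embed_avg_ema_update
cluster_size.mul_(decay).add_(encodings.sum(0), alpha=1 - decay)
embed_avg.mul_(decay).add_(encodings.t() @ flat, alpha=1 - decay)

# Laplace smoothing and codebook refresh, as in weight_update
n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
weight.copy_(embed_avg / smoothed.unsqueeze(1))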
@@ -384,37 +407,31 @@ def unmap_to_all(self, inds):
     def forward(self, z):
         # reshape z -> (batch, height, width, channel) and flatten
         #z, 'b c h w -> b h w c'
-        z = z.permute(0, 2, 3, 1).contiguous()
-        z_flattened = z.view(-1, self.embedding_dim)
+        z = rearrange(z, 'b c h w -> b h w c')
+        z_flattened = z.reshape(-1, self.codebook_dim)
+
         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'
 
-        d = torch.sum(z_flattened.pow(2), dim=1, keepdim=True) + \
-            torch.sum(self.embedding.weight.pow(2), dim=1) - 2 * \
-            torch.einsum('bd,dn->bn', z_flattened, self.embedding.weight.permute(1, 0))  # 'n d -> d n'
 
         encoding_indices = torch.argmin(d, dim=1)
+
         z_q = self.embedding(encoding_indices).view(z.shape)
-        encodings = F.one_hot(encoding_indices, self.n_embed).type(z.dtype)
+        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
         avg_probs = torch.mean(encodings, dim=0)
         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
 
-        if self.training:
-            encodings_sum = encodings.sum(0)
+        if self.training and self.embedding.update:
             #EMA cluster size
-            self.cluster_size.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
-
-            embed_sum = torch.matmul(encodings.t(), z_flattened)
+            encodings_sum = encodings.sum(0)
+            self.embedding.cluster_size_ema_update(encodings_sum)
             #EMA embedding average
-            self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
-
-            #cluster size Laplace smoothing
-            n = self.cluster_size.sum()
-            cluster_size = (
-                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            )
-            #normalize embedding average with smoothed cluster size
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-            self.embedding.weight.data.copy_(embed_normalized.data)
+            embed_sum = encodings.transpose(0, 1) @ z_flattened
+            self.embedding.embed_avg_ema_update(embed_sum)
+            #normalize embed_avg and update weight
+            self.embedding.weight_update(self.num_tokens)
 
         # compute loss for embedding
         loss = self.beta * F.mse_loss(z_q.detach(), z)
@@ -424,68 +441,5 @@ def forward(self, z):
 
         # reshape back to match original input shape
         #z_q, 'b h w c -> b c h w'
-        z_q = z_q.permute(0, 3, 1, 2).contiguous()
-        return z_q, loss, (perplexity, encodings, encoding_indices)
-
-
-
-#Original Sonnet version of EMAVectorQuantizer
-class EmbeddingEMA(nn.Module):
-    def __init__(self, n_embed, embedding_dim):
-        super().__init__()
-        weight = torch.randn(embedding_dim, n_embed)
-        self.register_buffer("weight", weight)
-        self.register_buffer("cluster_size", torch.zeros(n_embed))
-        self.register_buffer("embed_avg", weight.clone())
-
-    def forward(self, embed_id):
-        return F.embedding(embed_id, self.weight.transpose(0, 1))
-
-
-class SonnetEMAVectorQuantizer(nn.Module):
-    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
-                remap=None, unknown_index="random"):
-        super().__init__()
-        self.embedding_dim = embedding_dim
-        self.n_embed = n_embed
-        self.decay = decay
-        self.eps = eps
-        self.beta = beta
-        self.embedding = EmbeddingEMA(n_embed, embedding_dim)
-
-    def forward(self, z):
-        z = z.permute(0, 2, 3, 1).contiguous()
-        z_flattened = z.reshape(-1, self.embedding_dim)
-        d = (
-            z_flattened.pow(2).sum(1, keepdim=True)
-            - 2 * z_flattened @ self.embedding.weight
-            + self.embedding.weight.pow(2).sum(0, keepdim=True)
-        )
-        _, encoding_indices = (-d).max(1)
-        encodings = F.one_hot(encoding_indices, self.n_embed).type(z_flattened.dtype)
-        encoding_indices = encoding_indices.view(*z.shape[:-1])
-        z_q = self.embedding(encoding_indices)
-        avg_probs = torch.mean(encodings, dim=0)
-        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
-        if self.training:
-            encodings_sum = encodings.sum(0)
-            embed_sum = z_flattened.transpose(0, 1) @ encodings
-            #EMA cluster size
-            self.embedding.cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
-            #EMA embedding average
-            self.embedding.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
-
-            #cluster size Laplace smoothing
-            n = self.embedding.cluster_size.sum()
-            cluster_size = (
-                (self.embedding.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            )
-            #normalize embedding average with smoothed cluster size
-            embed_normalized = self.embedding.embed_avg / cluster_size.unsqueeze(0)
-            self.embedding.weight.data.copy_(embed_normalized)
-
-        loss = self.beta * (z_q.detach() - z).pow(2).mean()
-        z_q = z + (z_q - z).detach()
-        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+        z_q = rearrange(z_q, 'b h w c -> b c h w')
         return z_q, loss, (perplexity, encodings, encoding_indices)
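A minimal usage sketch of the rewritten quantizer follows; the import path, tensor sizes, and hyperparameters are illustrative assumptions rather than part of this commit, and it assumes torch and einops are installed so the rearrange calls in forward resolve.

import torch
# assumed import path; adjust to wherever EMAVectorQuantizer lives in this repo
from taming.modules.vqvae.quantize import EMAVectorQuantizer

# hypothetical sizes: 1024 codes of dimension 256; beta weights the commitment loss
quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)
quantizer.train()  # EMA codebook updates only run in training mode (and while embedding.update is True)

z = torch.randn(2, 256, 16, 16)  # encoder output: (batch, channels, height, width)
z_q, loss, (perplexity, encodings, encoding_indices) = quantizer(z)

print(z_q.shape)                        # torch.Size([2, 256, 16, 16]), straight-through quantized latents
print(loss.item(), perplexity.item())   # commitment loss and codebook-usage perplexity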