@@ -327,3 +327,165 @@ def get_codebook_entry(self, indices, shape):
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


# VectorQuantizer with exponential-moving-average (EMA) codebook updates; the
# codebook is not trained by gradient descent (requires_grad=False).
class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.beta = beta
        self.embedding = nn.Embedding(self.n_embed, self.embedding_dim)
        self.embedding.weight.requires_grad = False
        self.cluster_size = nn.Parameter(torch.zeros(n_embed), requires_grad=False)
        self.embed_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embedding_dim), requires_grad=False)
        self.embed_avg.data.copy_(self.embedding.weight.data)
        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened.pow(2), dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight.pow(2), dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, self.embedding.weight.permute(1, 0))  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.n_embed).type(z.dtype)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training:
            encodings_sum = encodings.sum(0)
            # EMA cluster size (update .data so the buffers never join the autograd graph)
            self.cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)

            embed_sum = torch.matmul(encodings.t(), z_flattened)
            # EMA embedding average
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

            # cluster size Laplace smoothing
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            # normalize embedding average with smoothed cluster size
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embedding.weight.data.copy_(embed_normalized.data)

        # compute loss for embedding (commitment loss only; the codebook is updated via EMA)
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, loss, (perplexity, encodings, encoding_indices)
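

# Minimal usage sketch (not part of the original change; the function name and the
# dummy shapes/hyperparameters below are illustrative only). It assumes the module's
# usual imports -- torch, numpy as np, torch.nn as nn, torch.nn.functional as F --
# are present at the top of the file.
def demo_ema_quantizer():
    quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25)
    quantizer.train()  # EMA codebook updates only run in training mode
    z = torch.randn(2, 64, 16, 16, requires_grad=True)  # (B, C=embedding_dim, H, W)
    z_q, loss, (perplexity, encodings, encoding_indices) = quantizer(z)
    assert z_q.shape == z.shape
    # the straight-through estimator lets gradients reach z even though the
    # nearest-code lookup is non-differentiable; the codebook gets no gradient
    (z_q.sum() + loss).backward()
    assert z.grad is not None and quantizer.embedding.weight.grad is None
    return z_q, loss, perplexity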


# Port of the EMA vector quantizer from the original (DeepMind Sonnet) VQ-VAE
# implementation; the codebook is stored transposed as (embedding_dim, n_embed)
# and kept in buffers that are updated manually rather than by the optimizer.
class EmbeddingEMA(nn.Module):
    def __init__(self, n_embed, embedding_dim):
        super().__init__()
        weight = torch.randn(embedding_dim, n_embed)
        self.register_buffer("weight", weight)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", weight.clone())

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight.transpose(0, 1))
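

# Shape sketch (illustrative, not part of the original change): EmbeddingEMA stores
# the codebook Sonnet-style as (embedding_dim, n_embed), so a lookup transposes it
# and returns embedding_dim-sized vectors along the last axis. Its buffers are
# rewritten manually by SonnetEMAVectorQuantizer below, not by the optimizer.
def demo_embedding_ema_lookup():
    codebook = EmbeddingEMA(n_embed=512, embedding_dim=64)
    ids = torch.randint(0, 512, (2, 16, 16))  # a spatial grid of code indices
    vectors = codebook(ids)                   # F.embedding against weight.t()
    assert vectors.shape == (2, 16, 16, 64)
    return vectors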


class SonnetEMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
        # remap / unknown_index are accepted for interface parity with
        # EMAVectorQuantizer above but are not used by this implementation
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.beta = beta
        self.embedding = EmbeddingEMA(n_embed, embedding_dim)

    def forward(self, z):
        # 'b c h w -> b h w c' and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.reshape(-1, self.embedding_dim)
        # squared distances to each codebook entry: z^2 - 2 z.e + e^2
        d = (
            z_flattened.pow(2).sum(1, keepdim=True)
            - 2 * z_flattened @ self.embedding.weight
            + self.embedding.weight.pow(2).sum(0, keepdim=True)
        )
        _, encoding_indices = (-d).max(1)
        encodings = F.one_hot(encoding_indices, self.n_embed).type(z_flattened.dtype)
        encoding_indices = encoding_indices.view(*z.shape[:-1])
        z_q = self.embedding(encoding_indices)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training:
            encodings_sum = encodings.sum(0)
            embed_sum = z_flattened.transpose(0, 1) @ encodings
            # EMA cluster size
            self.embedding.cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1 - self.decay)
            # EMA embedding average
            self.embedding.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

            # cluster size Laplace smoothing
            n = self.embedding.cluster_size.sum()
            cluster_size = (
                (self.embedding.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            # normalize embedding average with smoothed cluster size
            embed_normalized = self.embedding.embed_avg / cluster_size.unsqueeze(0)
            self.embedding.weight.data.copy_(embed_normalized)

        # commitment loss; the codebook is updated via EMA, gradients pass straight through
        loss = self.beta * (z_q.detach() - z).pow(2).mean()
        z_q = z + (z_q - z).detach()
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, loss, (perplexity, encodings, encoding_indices)
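

# Side-by-side sketch (illustrative assumption, not from the original change): both
# quantizers expose the same (z_q, loss, (perplexity, encodings, indices)) interface,
# and in training mode each rewrites its codebook in place from the EMA statistics
# instead of taking a gradient step on it.
def demo_compare_ema_quantizers():
    z = torch.randn(2, 64, 16, 16)
    for quantizer in (EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25),
                      SonnetEMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25)):
        quantizer.train()
        codebook_before = quantizer.embedding.weight.detach().clone()
        z_q, loss, (perplexity, encodings, encoding_indices) = quantizer(z)
        assert z_q.shape == z.shape
        # the EMA update changed the stored codebook without any optimizer involvement
        assert not torch.equal(codebook_before, quantizer.embedding.weight.detach())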