@@ -11,6 +11,7 @@
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
 import comfy.patcher_extension
 
 
@@ -31,6 +32,7 @@ def __init__(
         n_heads: int,
         n_kv_heads: Optional[int],
         qk_norm: bool,
+        out_bias: bool = False,
         operation_settings={},
     ):
         """
@@ -59,7 +61,7 @@ def __init__(
         self.out = operation_settings.get("operations").Linear(
             n_heads * self.head_dim,
             dim,
-            bias=False,
+            bias=out_bias,
             device=operation_settings.get("device"),
             dtype=operation_settings.get("dtype"),
         )
@@ -70,35 +72,6 @@ def __init__(
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
-    @staticmethod
-    def apply_rotary_emb(
-        x_in: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply rotary embeddings to input tensors using the given frequency
-        tensor.
-
-        This function applies rotary embeddings to the given query 'xq' and
-        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-        input tensors are reshaped as complex numbers, and the frequency tensor
-        is reshaped for broadcasting compatibility. The resulting tensors
-        contain rotary embeddings and are returned as real tensors.
-
-        Args:
-            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                exponentials.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                and key tensor with rotary embeddings.
-        """
-
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
-        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -134,8 +107,7 @@ def forward(
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)
 
-        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rope(xq, xk, freqs_cis)
 
         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
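Note: the removed per-tensor helper is replaced by the shared apply_rope from comfy.ldm.flux.math, which rotates the query and key in a single call. A minimal sketch of the equivalent operation, reconstructed from the removed apply_rotary_emb (the actual flux implementation may differ in dtype handling):

import torch

def apply_rope_sketch(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # freqs_cis[..., 0] and freqs_cis[..., 1] act as the real/imaginary parts
    # of the rotation applied to each channel pair.
    def rotate(t: torch.Tensor) -> torch.Tensor:
        t_ = t.reshape(*t.shape[:-1], -1, 1, 2)
        out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return out.reshape(*t.shape).type_as(t)

    return rotate(xq), rotate(xk)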
@@ -215,6 +187,8 @@ def __init__(
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        z_image_modulation=False,
+        attn_out_bias=False,
         operation_settings={},
     ) -> None:
         """
@@ -235,10 +209,10 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
         self.feed_forward = FeedForward(
             dim=dim,
-            hidden_dim=4 * dim,
+            hidden_dim=dim,
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
             operation_settings=operation_settings,
@@ -252,16 +226,27 @@ def __init__(
 
         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024),
-                    4 * dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
+            if z_image_modulation:
+                self.adaLN_modulation = nn.Sequential(
+                    operation_settings.get("operations").Linear(
+                        min(dim, 256),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
+            else:
+                self.adaLN_modulation = nn.Sequential(
+                    nn.SiLU(),
+                    operation_settings.get("operations").Linear(
+                        min(dim, 1024),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
 
     def forward(
         self,
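Note: with z_image_modulation the block takes a 256-dim conditioning vector (matching the t_embedder output_size=256 further down) and drops the leading SiLU; otherwise the original SiLU + min(dim, 1024) projection is kept. Both paths output 4 * dim values. A rough sketch of the two variants, using plain nn.Linear as a stand-in for the operations wrapper and illustrative names for the four modulation chunks:

import torch
import torch.nn as nn

def build_adaln(dim: int, z_image_modulation: bool) -> nn.Sequential:
    if z_image_modulation:
        # 256-dim conditioning, no SiLU in front of the projection.
        return nn.Sequential(nn.Linear(min(dim, 256), 4 * dim, bias=True))
    # Original path: SiLU followed by a projection from min(dim, 1024).
    return nn.Sequential(nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True))

# The 4 * dim output is consumed as four per-block modulation terms.
mods = build_adaln(2304, z_image_modulation=True)(torch.randn(1, 256))
scale_msa, gate_msa, scale_mlp, gate_mlp = mods.chunk(4, dim=1)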
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
     The final layer of NextDiT.
     """
 
-    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+    def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
         super().__init__()
         self.norm_final = operation_settings.get("operations").LayerNorm(
             hidden_size,
@@ -340,10 +325,15 @@ def __init__(self, hidden_size, patch_size, out_channels, operation_settings={})
             dtype=operation_settings.get("dtype"),
         )
 
+        if z_image_modulation:
+            min_mod = 256
+        else:
+            min_mod = 1024
+
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             operation_settings.get("operations").Linear(
-                min(hidden_size, 1024),
+                min(hidden_size, min_mod),
                 hidden_size,
                 bias=True,
                 device=operation_settings.get("device"),
@@ -373,12 +363,16 @@ def __init__(
         n_heads: int = 32,
         n_kv_heads: Optional[int] = None,
         multiple_of: int = 256,
-        ffn_dim_multiplier: Optional[float] = None,
+        ffn_dim_multiplier: float = 4.0,
         norm_eps: float = 1e-5,
         qk_norm: bool = False,
         cap_feat_dim: int = 5120,
         axes_dims: List[int] = (16, 56, 56),
         axes_lens: List[int] = (1, 512, 512),
+        rope_theta=10000.0,
+        z_image_modulation=False,
+        time_scale=1.0,
+        pad_tokens_multiple=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -390,6 +384,8 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
+        self.time_scale = time_scale
+        self.pad_tokens_multiple = pad_tokens_multiple
 
         self.x_embedder = operation_settings.get("operations").Linear(
             in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@ def __init__(
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    z_image_modulation=z_image_modulation,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@ def __init__(
             ]
         )
 
-        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+        self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
         self.cap_embedder = nn.Sequential(
             operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@ def __init__(
                     ffn_dim_multiplier,
                     norm_eps,
                     qk_norm,
+                    z_image_modulation=z_image_modulation,
+                    attn_out_bias=False,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_layers)
             ]
         )
         self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
+
+        if self.pad_tokens_multiple is not None:
+            self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+            self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
 
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
         self.dim = dim
         self.n_heads = n_heads
 
@@ -503,108 +506,42 @@ def patchify_and_embed(
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
-        dtype = x[0].dtype
 
-        if cap_mask is not None:
-            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
-        else:
-            l_effective_cap_len = [num_tokens] * bsz
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
 
-        if cap_mask is not None and not torch.is_floating_point(cap_mask):
-            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
+        cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+        cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
 
-        img_sizes = [(img.size(1), img.size(2)) for img in x]
-        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
+        B, C, H, W = x.shape
+        x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
 
-        max_seq_len = max(
-            (cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
-        )
-        max_cap_len = max(l_effective_cap_len)
-        max_img_len = max(l_effective_img_len)
-
-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // pH, W // pW
-            assert H_tokens * W_tokens == img_len
-
-            rope_options = transformer_options.get("rope_options", None)
-            h_scale = 1.0
-            w_scale = 1.0
-            h_start = 0
-            w_start = 0
-            if rope_options is not None:
-                h_scale = rope_options.get("scale_y", 1.0)
-                w_scale = rope_options.get("scale_x", 1.0)
-
-                h_start = rope_options.get("shift_y", 0.0)
-                w_start = rope_options.get("shift_x", 0.0)
-
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
-            position_ids[i, cap_len:cap_len + img_len, 0] = cap_len
-            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
-            position_ids[i, cap_len:cap_len + img_len, 1] = row_ids
-            position_ids[i, cap_len:cap_len + img_len, 2] = col_ids
-
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
-
-        # build freqs_cis for cap and image individually
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        # cap_freqs_cis_shape[1] = max_cap_len
-        cap_freqs_cis_shape[1] = cap_feats.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len + img_len]
+        H_tokens, W_tokens = H // pH, W // pW
+        x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
+        x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
 
-        # refine context
-        for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
-
-        # refine image
-        flat_x = []
-        for i in range(bsz):
-            img = x[i]
-            C, H, W = img.size()
-            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_x.append(img)
-        x = flat_x
-        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
-        for i in range(bsz):
-            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
-
-        padded_img_embed = self.x_embedder(padded_img_embed)
-        padded_img_mask = padded_img_mask.unsqueeze(1)
-        for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
 
-        if cap_mask is not None:
-            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
-            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
-        else:
-            mask = None
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
 
-        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
+        # refine context
+        for layer in self.context_refiner:
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
 
-            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
-            padded_full_embed[i, cap_len:cap_len + img_len] = padded_img_embed[i, :img_len]
+        padded_img_mask = None
+        for layer in self.noise_refiner:
+            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
 
+        padded_full_embed = torch.cat((cap_feats, x), dim=1)
+        mask = None
+        img_sizes = [(H, W)] * bsz
+        l_effective_cap_len = [cap_feats.shape[1]] * bsz
         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
 
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
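Note: the rewritten patchify_and_embed assumes a single batched image tensor and drops the per-sample padding/masking loops. Caption tokens get axis-0 position ids 1..L, every image token shares axis-0 id L + 1, and axes 1/2 carry the patch row/column. A small sketch of the patch-token reshape with toy sizes (not the model's real dimensions):

import torch

B, C, H, W, p = 1, 16, 4, 6, 2  # toy sizes; the model uses larger latent grids
x = torch.randn(B, C, H, W)
tokens = (
    x.view(B, C, H // p, p, W // p, p)
    .permute(0, 2, 4, 3, 5, 1)  # (B, H/p, W/p, p, p, C)
    .flatten(3)                 # each patch becomes one p*p*C vector
    .flatten(1, 2)              # the (H/p, W/p) grid becomes the token axis
)
assert tokens.shape == (B, (H // p) * (W // p), p * p * C)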
@@ -627,7 +564,7 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwa
         y: (N,) tensor of text tokens/features
         """
 
-        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
+        t = self.t_embedder(t * self.time_scale, dtype=x.dtype)  # (N, D)
         adaln_input = t
 
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute