 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .normalization import (
+    AdaLayerNorm,
+    AdaLayerNormContinuous,
+    AdaLayerNormZero,
+    RMSNorm,
+    SD35AdaLayerNormZeroX,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -122,7 +128,12 @@ def __init__(
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
-                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+                dim,
+                dim,
+                elementwise_affine=False,
+                eps=1e-6,
+                bias=True,
+                norm_type="layer_norm",
             )
         elif context_norm_type == "ada_norm_zero":
             self.norm1_context = AdaLayerNormZero(dim)
@@ -188,33 +199,51 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = None,
     ):
         if self.use_dual_attention:
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
-                hidden_states, emb=temb
-            )
+            (
+                norm_hidden_states,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                norm_hidden_states2,
+                gate_msa2,
+            ) = self.norm1(hidden_states, emb=temb)
         else:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
         else:
-            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
-                encoder_hidden_states, emb=temb
-            )
+            (
+                norm_encoder_hidden_states,
+                c_gate_msa,
+                c_shift_mlp,
+                c_scale_mlp,
+                c_gate_mlp,
+            ) = self.norm1_context(encoder_hidden_states, emb=temb)
+
+        joint_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
 
@@ -241,7 +270,10 @@ def forward(
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
             context_ff_output = _chunked_feed_forward(
-                self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                self.ff_context,
+                norm_encoder_hidden_states,
+                self._chunk_dim,
+                self._chunk_size,
             )
         else:
             context_ff_output = self.ff_context(norm_encoder_hidden_states)
@@ -402,7 +434,7 @@ def __init__(
 
         self.attn2 = Attention(
             query_dim=dim,
-            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+            cross_attention_dim=(cross_attention_dim if not double_self_attention else None),
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
@@ -506,7 +538,7 @@ def forward(
 
         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None),
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
@@ -979,7 +1011,7 @@ def __init__(
 
         self.attn2 = Attention(
             query_dim=dim,
-            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+            cross_attention_dim=(cross_attention_dim if not double_self_attention else None),
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
@@ -1045,7 +1077,10 @@ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid")
         return weights
 
     def set_free_noise_properties(
-        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+        self,
+        context_length: int,
+        context_stride: int,
+        weighting_scheme: str = "pyramid",
     ) -> None:
         self.context_length = context_length
         self.context_stride = context_stride
@@ -1112,7 +1147,7 @@ def forward(
 
         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None),
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
@@ -1158,7 +1193,11 @@ def forward(
         # looked into this deeply because other memory optimizations led to more pronounced reductions.
         hidden_states = torch.cat(
             [
-                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                torch.where(
+                    num_times_split > 0,
+                    accumulated_split / num_times_split,
+                    accumulated_split,
+                )
                 for accumulated_split, num_times_split in zip(
                     accumulated_values.split(self.context_length, dim=1),
                     num_times_accumulated.split(self.context_length, dim=1),
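For context, a minimal usage sketch of the new `joint_attention_kwargs` argument added to the joint block's `forward` above. It assumes the block being edited is `JointTransformerBlock` in `diffusers/models/attention.py` (the class name is not visible in this diff); the constructor values, tensor shapes, and import path are illustrative assumptions, not part of the change itself.

```python
import torch

# Assumption: the forward() shown in the diff belongs to JointTransformerBlock.
from diffusers.models.attention import JointTransformerBlock

# Illustrative sizes only: dim must equal num_attention_heads * attention_head_dim.
block = JointTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64)

hidden_states = torch.randn(2, 4096, 1536)        # latent/image tokens
encoder_hidden_states = torch.randn(2, 77, 1536)  # text tokens
temb = torch.randn(2, 1536)                       # timestep embedding

# The new keyword argument: extra kwargs are copied inside forward() and
# forwarded to the attention call(s) via **joint_attention_kwargs.
outputs = block(
    hidden_states,
    encoder_hidden_states,
    temb,
    joint_attention_kwargs={},  # processor-specific options; empty here
)
```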