@@ -188,7 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = None,
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -206,15 +210,17 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2