@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,15 +211,17 @@ def forward(

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2

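The change is a standard optional-kwargs pass-through: the new keyword defaults to `None`, is coerced to an empty dict, and its entries are spread into `self.attn` (and `self.attn2` when dual attention is enabled), so existing call sites keep working unchanged. Below is a minimal, self-contained sketch of the same pattern using a toy module rather than the actual `JointTransformerBlock`, with `torch.nn.MultiheadAttention` standing in for the joint attention; names, dims, and the forwarded key are illustrative only.

import torch
from typing import Any, Dict, Optional


class TinyBlock(torch.nn.Module):
    """Toy stand-in (not the diffusers class) showing the optional-kwargs pass-through."""

    def __init__(self, dim: int):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Default to an empty dict so old callers need no changes,
        # then spread the entries into the inner attention call.
        joint_attention_kwargs = joint_attention_kwargs or {}
        out, _ = self.attn(hidden_states, hidden_states, hidden_states, **joint_attention_kwargs)
        return out


block = TinyBlock(dim=64)
x = torch.randn(2, 16, 64)
block(x)                                                   # existing call sites still work
block(x, joint_attention_kwargs={"need_weights": False})   # extra kwargs reach the attention module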