@@ -200,8 +200,6 @@ def forward(
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
-        # if encoder_hidden_states.dtype == torch.float16:
-        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
         return encoder_hidden_states, hidden_states
 
@@ -257,11 +255,6 @@ def forward(
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
 
-        temb = (
-            self.time_text_embed(timestep, pooled_projections)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, pooled_projections)
-        )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         if txt_ids.ndim == 3:
@@ -286,24 +279,13 @@ def forward(
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
         for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    joint_attention_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=adaln_emb[index_block],
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=adaln_emb[index_block],
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
 
             # controlnet residual
             if controlnet_block_samples is not None:
@@ -318,24 +300,13 @@ def forward(
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    joint_attention_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=adaln_single_emb[index_block],
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=adaln_single_emb[index_block],
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
 
             # controlnet residual
             if controlnet_single_block_samples is not None:
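
The loop bodies added above follow a per-block conditioning pattern: each transformer block gets its own embedding via adaln_emb[index_block] (or adaln_single_emb[index_block]) rather than one shared temb. The sketch below is a minimal, self-contained illustration of that pattern only, not the repository's actual code; ToyBlock, num_blocks, and the tensor shapes are illustrative assumptions.

# Minimal sketch of per-block conditioning (assumed names and shapes).
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        # Add the block-specific conditioning embedding, then project.
        hidden_states = self.proj(hidden_states + temb.unsqueeze(1))
        return encoder_hidden_states, hidden_states


num_blocks, batch, seq_len, dim = 4, 2, 16, 32
blocks = nn.ModuleList([ToyBlock(dim) for _ in range(num_blocks)])

hidden_states = torch.randn(batch, seq_len, dim)
encoder_hidden_states = torch.randn(batch, 8, dim)
adaln_emb = torch.randn(num_blocks, batch, dim)  # one embedding per block

for index_block, block in enumerate(blocks):
    encoder_hidden_states, hidden_states = block(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=adaln_emb[index_block],
    )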