Commit 226e65a

Author: Amit Raj

Code cleanup and fixes
Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>
1 parent 09f6f80 commit 226e65a

File tree: 12 files changed, +178 -177 lines


QEfficient/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -50,8 +50,6 @@ def check_qaic_sdk():
         QEFFCommonLoader,
     )
     from QEfficient.compile.compile_helper import compile
-
-    # Imports for the diffusers
     from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
     from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
     from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
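For orientation, these re-exports are the entry points a user touches at the package top level, with QEFFFluxPipeline now exposed alongside the existing loaders. A minimal usage sketch follows; it assumes QEFFFluxPipeline mirrors the familiar diffusers from_pretrained/__call__ interface, and the model id and call arguments are illustrative, not taken from this commit:

    # Hedged sketch, not part of this commit: assumes QEFFFluxPipeline follows
    # the diffusers pipeline convention (from_pretrained + __call__).
    from QEfficient import QEFFFluxPipeline

    pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
    image = pipeline("a cabin in the mountains at dusk", num_inference_steps=4).images[0]
    image.save("flux_out.png")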

QEfficient/diffusers/models/autoencoders/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

QEfficient/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 14 additions & 43 deletions
@@ -200,8 +200,6 @@ def forward(
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
-        # if encoder_hidden_states.dtype == torch.float16:
-        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
         return encoder_hidden_states, hidden_states
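The deleted guard existed because 65504 is the largest finite float16 value; clipping kept the residual stream from overflowing to inf when the block ran in fp16. A small illustrative check, not part of this commit:

    import torch

    # Largest finite float16 value, hence the (-65504, 65504) range in the removed clip.
    print(torch.finfo(torch.float16).max)            # 65504.0

    x = torch.tensor([1.0e5, -1.0e5])
    print(x.to(torch.float16))                       # tensor([inf, -inf], dtype=torch.float16)
    print(x.clip(-65504, 65504).to(torch.float16))   # tensor([ 65504., -65504.], dtype=torch.float16)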

@@ -257,11 +255,6 @@ def forward(
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
 
-        temb = (
-            self.time_text_embed(timestep, pooled_projections)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, pooled_projections)
-        )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         if txt_ids.ndim == 3:
@@ -286,24 +279,13 @@ def forward(
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
         for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    joint_attention_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=adaln_emb[index_block],
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=adaln_emb[index_block],
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
 
             # controlnet residual
             if controlnet_block_samples is not None:
@@ -318,24 +300,13 @@ def forward(
                 hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    joint_attention_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=adaln_single_emb[index_block],
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=adaln_single_emb[index_block],
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
 
             # controlnet residual
             if controlnet_single_block_samples is not None:
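Both loops are now written for inference-only execution: the gradient-checkpointing branch is gone, and each block receives a precomputed per-block AdaLN embedding (adaln_emb / adaln_single_emb) instead of the shared temb removed earlier in this file. A toy sketch of that dispatch pattern, using stand-in blocks and a randomly initialized embedding table (everything here except the adaln_emb indexing pattern is an assumption for illustration):

    import torch
    from torch import nn

    class DummyBlock(nn.Module):
        # Stand-in for a transformer block that consumes a per-block AdaLN embedding.
        def forward(self, hidden_states, encoder_hidden_states, temb,
                    image_rotary_emb=None, joint_attention_kwargs=None):
            return encoder_hidden_states, hidden_states + temb

    blocks = nn.ModuleList(DummyBlock() for _ in range(4))
    hidden_states = torch.zeros(1, 8, 16)
    encoder_hidden_states = torch.zeros(1, 4, 16)
    adaln_emb = torch.randn(len(blocks), 16)  # assumed precomputed, one row per block

    for index_block, block in enumerate(blocks):
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=adaln_emb[index_block],
            image_rotary_emb=None,
            joint_attention_kwargs=None,
        )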
