Skip to content

Commit 22bd6b0

Browse files
authored
Fix Flux tuning issue (#936)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
1 parent 873114a commit 22bd6b0

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

auto_round/compressors/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,9 +2527,9 @@ def _get_current_num_elm(
25272527
def _quantize_block(
25282528
self,
25292529
block: torch.nn.Module,
2530-
input_ids: list[torch.Tensor],
2530+
input_ids: Union[list[torch.Tensor], dict],
25312531
input_others: dict,
2532-
q_input: Union[None, torch.Tensor] = None,
2532+
q_input: Union[torch.Tensor, dict, None] = None,
25332533
device: Union[str, torch.device] = "cpu",
25342534
):
25352535
"""Quantize the weights of a given block of the model.
@@ -2646,7 +2646,11 @@ def _quantize_block(
26462646
else:
26472647
lr_schedule = copy.deepcopy(self.lr_scheduler)
26482648

2649-
nsamples = len(input_ids)
2649+
if isinstance(input_ids, dict): # input_ids of Flux is dict
2650+
nsamples = len(input_ids["hidden_states"])
2651+
else:
2652+
nsamples = len(input_ids)
2653+
26502654
pick_samples = self.batch_size * self.gradient_accumulate_steps
26512655
pick_samples = min(nsamples, pick_samples)
26522656
if self.sampler != "rand":

auto_round/compressors/diffusion/compressor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def _get_current_q_output(
210210
def _get_block_outputs(
211211
self,
212212
block: torch.nn.Module,
213-
input_ids: torch.Tensor,
213+
input_ids: Union[torch.Tensor, dict],
214214
input_others: torch.Tensor,
215215
bs: int,
216216
device: Union[str, torch.device],
@@ -233,8 +233,11 @@ def _get_block_outputs(
233233
"""
234234

235235
output = defaultdict(list)
236-
nsamples = len(input_ids)
237236
output_config = output_configs.get(block.__class__.__name__, [])
237+
if isinstance(input_ids, dict):
238+
nsamples = len(input_ids["hidden_states"])
239+
else:
240+
nsamples = len(input_ids)
238241

239242
for i in range(0, nsamples, bs):
240243
end_index = min(nsamples, i + bs)

0 commit comments

Comments
 (0)