File tree: 2 files changed, +12 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -2527,9 +2527,9 @@ def _get_current_num_elm(
25272527 def _quantize_block (
25282528 self ,
25292529 block : torch .nn .Module ,
2530- input_ids : list [torch .Tensor ],
2530+ input_ids : Union [ list [torch .Tensor ], dict ],
25312531 input_others : dict ,
2532- q_input : Union [None , torch .Tensor ] = None ,
2532+ q_input : Union [torch .Tensor , dict , None ] = None ,
25332533 device : Union [str , torch .device ] = "cpu" ,
25342534 ):
25352535 """Quantize the weights of a given block of the model.
@@ -2646,7 +2646,11 @@ def _quantize_block(
26462646 else :
26472647 lr_schedule = copy .deepcopy (self .lr_scheduler )
26482648
2649- nsamples = len (input_ids )
2649+ if isinstance (input_ids , dict ): # input_ids of Flux is dict
2650+ nsamples = len (input_ids ["hidden_states" ])
2651+ else :
2652+ nsamples = len (input_ids )
2653+
26502654 pick_samples = self .batch_size * self .gradient_accumulate_steps
26512655 pick_samples = min (nsamples , pick_samples )
26522656 if self .sampler != "rand" :
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def _get_current_q_output(
210210 def _get_block_outputs (
211211 self ,
212212 block : torch .nn .Module ,
213- input_ids : torch .Tensor ,
213+ input_ids : Union [ torch .Tensor , dict ] ,
214214 input_others : torch .Tensor ,
215215 bs : int ,
216216 device : Union [str , torch .device ],
@@ -233,8 +233,11 @@ def _get_block_outputs(
233233 """
234234
235235 output = defaultdict (list )
236- nsamples = len (input_ids )
237236 output_config = output_configs .get (block .__class__ .__name__ , [])
237+ if isinstance (input_ids , dict ):
238+ nsamples = len (input_ids ["hidden_states" ])
239+ else :
240+ nsamples = len (input_ids )
238241
239242 for i in range (0 , nsamples , bs ):
240243 end_index = min (nsamples , i + bs )
You can’t perform that action at this time.
0 commit comments