@@ -533,19 +533,24 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
533
533
self .fetch_from_cache = mod .fetch_from_cache
534
534
self .forward = self .forward_measure
535
535
536
- def forward (self , input , cache , block_indices , block_offset ):
536
+ def forward (self , input , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset ):
537
537
qinput = self .quant_input (input )
538
- output_cache = self .forward_orig (qinput , cache , block_indices , block_offset )
538
+ output_cache = self .forward_orig (qinput , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset )
539
539
return self .quant_output (output_cache )
540
540
541
- def forward_measure (self , input , cache , block_indices , block_offset ):
541
+ def forward_measure (self , input , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset ):
542
542
measure_input ((input ), self ._mod_extra_config .inputs )
543
- output_cache = self .forward_orig (input , cache , block_indices , block_offset )
543
+ output_cache = self .forward_orig (input , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset )
544
544
measure_output ((output_cache ), self ._mod_extra_config .outputs )
545
545
return output_cache
546
546
547
- def fetch_from_cache (self , cache , blocks ):
547
+ def fetch_from_cache (self , cache , blocks , permutations = None ):
548
548
quant_cache = self .quant_input (cache )
549
+ if permutations :
550
+ output_cache = self .orig_fetch_from_cache (quant_cache , blocks , permutations )
551
+ for i in range (len (output_cache )):
552
+ output_cache [i ]= self .quant_output (output_cache [i ])
553
+ return output_cache
549
554
output_cache = self .orig_fetch_from_cache (quant_cache , blocks )
550
555
return self .quant_output (output_cache )
551
556
0 commit comments