Skip to content

Commit c19fcbd

Browse files
committed
Adjust INC to run from vLLM with old PA
Change-Id: Ifdea6840aaa22791f478ad10788e5d47fd4a0394
1 parent ff114b7 commit c19fcbd

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -533,19 +533,24 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
533533
self.fetch_from_cache = mod.fetch_from_cache
534534
self.forward = self.forward_measure
535535

536-
def forward(self, input, cache, block_indices, block_offset):
536+
def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
537537
qinput = self.quant_input(input)
538-
output_cache = self.forward_orig(qinput, cache, block_indices, block_offset)
538+
output_cache = self.forward_orig(qinput, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
539539
return self.quant_output(output_cache)
540540

541-
def forward_measure(self, input, cache, block_indices, block_offset):
541+
def forward_measure(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
542542
measure_input((input), self._mod_extra_config.inputs)
543-
output_cache = self.forward_orig(input, cache, block_indices, block_offset)
543+
output_cache = self.forward_orig(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
544544
measure_output((output_cache), self._mod_extra_config.outputs)
545545
return output_cache
546546

547-
def fetch_from_cache(self, cache, blocks):
547+
def fetch_from_cache(self, cache, blocks, permutations=None):
548548
quant_cache = self.quant_input(cache)
549+
if permutations:
550+
output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations)
551+
for i in range(len(output_cache)):
552+
output_cache[i]=self.quant_output(output_cache[i])
553+
return output_cache
549554
output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
550555
return self.quant_output(output_cache)
551556

0 commit comments

Comments
 (0)