Skip to content

Commit 7c0a3e2

Browse files
committed
fix vllmkvcache
Change-Id: I2916bb6d9c1c6b70be115d6d2b78959a0681f63f Signed-off-by: Yi Liu <yiliu4@habana.ai>
1 parent 1f8bcd6 commit 7c0a3e2

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def prepare_model(model, mod_list=None):
108108
d_shapes = None
109109
gmod_list.extend(mod_list)
110110
generate_model_info(model)
111-
logger.info(f"generated model info")
112-
for mod, name in parent_child_mod_dict.items():
113-
logger.info(f"mod: {mod}, name: {name}")
111+
# logger.info(f"generated model info")
112+
# for mod, name in parent_child_mod_dict.items():
113+
# logger.info(f"mod: {mod}, name: {name}")
114114
register_patched_measure_modules(model, mod_list, observer_class, d_shapes)
115115

116116

@@ -160,7 +160,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
160160
)
161161
logger.info(f"Patching measure module {name} {mod.__class__} ")
162162
pmod = patch_module_measure(mod, mod_extra_config, mod_default_dict)
163-
logger.info(f"Pacthed module pmod: {pmod}")
163+
# logger.info(f"Patched module pmod: {pmod}")
164164
if pmod._mod_extra_config:
165165
for param_name in pmod._mod_extra_config.params:
166166
param = getattr(pmod, param_name)

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -878,26 +878,43 @@ def forward_qdq(self, input, *args, **kwargs):
878878
output_cache = self.orig_mod(qinput, *args, **kwargs)
879879
return output_cache
880880

881-
def forward_quant(self, input, *args, **kwargs):
882-
qinput = self.quant_input(input)
883-
output_cache = self.orig_mod(qinput, *args, **kwargs)
884-
return self.dequant_output(output_cache)
881+
# def forward_quant(self, input, *args, **kwargs):
882+
# qinput = self.quant_input(input)
883+
# output_cache = self.orig_mod(qinput, *args, **kwargs)
884+
# return self.dequant_output(output_cache)
885885

886886
def forward_measure(self, input, *args, **kwargs):
887887
measure_input((input), self._mod_extra_config.inputs)
888888
output_cache = self.orig_mod(input, *args, **kwargs)
889889
measure_output((output_cache), self._mod_extra_config.outputs)
890890
return output_cache
891891

892-
def fetch_from_cache(self, cache, blocks, permutations=None):
893-
quant_cache = self.quant_input(cache)
892+
# def fetch_from_cache(self, cache, blocks, permutations=None):
893+
# # quant_cache = self.quant_input(cache)
894+
# quant_cache = cache
895+
# if permutations:
896+
# output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
897+
# for i in range(len(output_cache)):
898+
# output_cache[i] = self.dequant_output(output_cache[i])
899+
# return output_cache
900+
# output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
901+
# return self.dequant_output(output_cache)
902+
903+
def forward_quant(self, input, *args, **kwargs):
904+
qinput = self.quant_input(input)
905+
return self.orig_mod(qinput, *args, **kwargs)
906+
907+
def fetch_from_cache(self, quant_cache, blocks, permutations=None):
894908
if permutations:
895909
output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
896910
for i in range(len(output_cache)):
897911
output_cache[i] = self.dequant_output(output_cache[i])
898912
return output_cache
899913
output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
900914
return self.dequant_output(output_cache)
915+
916+
def extra_repr(self) -> str:
917+
return f"PatchedVLLMKVCache"
901918

902919
def init_conv(instance, mod_extra_config):
903920
if instance.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:

0 commit comments

Comments
 (0)