@@ -878,26 +878,43 @@ def forward_qdq(self, input, *args, **kwargs):
878878 output_cache = self .orig_mod (qinput , * args , ** kwargs )
879879 return output_cache
880880
881- def forward_quant (self , input , * args , ** kwargs ):
882- qinput = self .quant_input (input )
883- output_cache = self .orig_mod (qinput , * args , ** kwargs )
884- return self .dequant_output (output_cache )
881+ # def forward_quant(self, input, *args, **kwargs):
882+ # qinput = self.quant_input(input)
883+ # output_cache = self.orig_mod(qinput, *args, **kwargs)
884+ # return self.dequant_output(output_cache)
885885
def forward_measure(self, input, *args, **kwargs):
    """Run the wrapped module unquantized while recording its input and output.

    The input is handed to ``measure_input`` together with
    ``self._mod_extra_config.inputs``; the result of ``self.orig_mod`` is then
    handed to ``measure_output`` together with ``self._mod_extra_config.outputs``.

    Args:
        input: Value forwarded unchanged to the wrapped module.
        *args: Extra positional arguments forwarded to ``self.orig_mod``.
        **kwargs: Extra keyword arguments forwarded to ``self.orig_mod``.

    Returns:
        Whatever ``self.orig_mod`` returned, unmodified.
    """
    # NOTE(review): the values are passed bare (not wrapped in a 1-tuple);
    # confirm that is what measure_input / measure_output expect.
    measure_input(input, self._mod_extra_config.inputs)
    result = self.orig_mod(input, *args, **kwargs)
    measure_output(result, self._mod_extra_config.outputs)
    return result
891891
892- def fetch_from_cache (self , cache , blocks , permutations = None ):
893- quant_cache = self .quant_input (cache )
892+ # def fetch_from_cache(self, cache, blocks, permutations=None):
893+ # # quant_cache = self.quant_input(cache)
894+ # quant_cache = cache
895+ # if permutations:
896+ # output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
897+ # for i in range(len(output_cache)):
898+ # output_cache[i] = self.dequant_output(output_cache[i])
899+ # return output_cache
900+ # output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
901+ # return self.dequant_output(output_cache)
902+
def forward_quant(self, input, *args, **kwargs):
    """Quantize ``input`` and feed it to the wrapped module.

    Args:
        input: Value passed through ``self.quant_input`` before the call.
        *args: Extra positional arguments forwarded to ``self.orig_mod``.
        **kwargs: Extra keyword arguments forwarded to ``self.orig_mod``.

    Returns:
        The wrapped module's result as-is; no dequantization is applied here.
    """
    return self.orig_mod(self.quant_input(input), *args, **kwargs)
906+
def fetch_from_cache(self, quant_cache, blocks, permutations=None):
    """Fetch blocks from the cache via the wrapped module and dequantize them.

    Args:
        quant_cache: Cache object forwarded unchanged to
            ``self.orig_mod.fetch_from_cache``.
        blocks: Block selector forwarded unchanged to the wrapped module.
        permutations: Optional value forwarded as a third positional argument.
            When it is falsy (``None`` or empty) the two-argument form of the
            wrapped call is used instead.

    Returns:
        With truthy ``permutations``: the sequence returned by the wrapped
        module, with every element replaced in place by its
        ``self.dequant_output`` result. Otherwise: ``self.dequant_output``
        applied to the single fetched value.
    """
    if permutations:
        fetched = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
        # Overwrite each slot in place so the caller receives the same
        # container object the wrapped module produced.
        for idx, entry in enumerate(fetched):
            fetched[idx] = self.dequant_output(entry)
        return fetched
    return self.dequant_output(self.orig_mod.fetch_from_cache(quant_cache, blocks))
915+
def extra_repr(self) -> str:
    """Return the fixed label ``"PatchedVLLMKVCache"`` for this wrapper.

    NOTE(review): presumably this is the ``torch.nn.Module.extra_repr`` hook,
    so the label would show up inside the module's printed representation;
    confirm against the enclosing class.
    """
    # Plain literal: the text has no placeholders, so no f-string is needed.
    return "PatchedVLLMKVCache"
901918
902919def init_conv (instance , mod_extra_config ):
903920 if instance .quantization_mode in [QuantMode .QUANTIZE , QuantMode .LOAD ]:
0 commit comments