
Commit 04d7a0b

fix quantized, add tests
1 parent 1c3cbcc commit 04d7a0b

3 files changed: +232 -136 lines changed

src/transformers/cache_utils.py

Lines changed: 168 additions & 134 deletions
@@ -214,7 +214,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(K={self.key_cache}, V={self.value_cache})"
+        key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})"
+        value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})"
+        return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"
 
 
 class Cache:
@@ -860,7 +862,7 @@ def update(
     def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int:
         """Returns the sequence length of the cached states."""
         # TODO: deprecate this function in favor of `cache_position`
-        if self is None or self.key_cache is None:
+        if self is None or self.key_cache is None or self.key_cache.numel() == 0:
             return 0
         return self.key_cache.shape[-2]
 
@@ -1017,94 +1019,6 @@ def __init__(self, config: Optional[CacheConfig] = None) -> None:
         super().__init__(processors=processors, config=config)
 
 
-class QuantoQuantizedCache(DynamicCache):
-    """
-    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
-    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
-
-    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
-    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
-    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
-
-    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
-    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
-    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
-
-    Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
-
-    Parameters:
-        cache_config (`QuantizedCacheConfig`):
-            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
-
-    Example:
-
-        ```python
-        >>> # Run pip install quanto first if you don't have it yet
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> cache_config = QuantizedCacheConfig(nbits=4)
-        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        QuantoQuantizedCache()
-        ```
-    """
-
-    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
-        super(DynamicCache, self).__init__(processors=processors)
-
-
-class HQQQuantizedCache(DynamicCache):
-    """
-    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
-    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
-
-    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
-    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
-    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
-
-    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
-    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
-    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
-
-    Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
-
-    Parameters:
-        cache_config (`QuantizedCacheConfig`):
-            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
-
-    Example:
-
-        ```python
-        >>> # Run pip install hqq first if you don't have it yet
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
-        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        HQQQuantizedCache()
-        ```
-    """
-
-    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)])
-        super(DynamicCache, self).__init__(processors=processors)
-
-
 class StaticLayer(CacheLayer):
     is_compileable = True
 
@@ -2120,56 +2034,60 @@ def post_update(
         if layer_idx == 0:
             self._seen_tokens += key_tensors.shape[-2]
 
-        # Extend quantized cache if needed
-        while len(self._quantized_key_cache) <= layer_idx:
-            self._quantized_key_cache.append(torch.empty(0))
-            self._quantized_value_cache.append(torch.empty(0))
+        if len(cache.key_cache) < layer_idx:
+            raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
 
-        # Check if we need to quantize
-        if layer_idx < len(cache.key_cache):
-            current_key = cache.key_cache[layer_idx]
-            current_value = cache.value_cache[layer_idx]
+        # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer
+        # On the first forward pass, we quantize the whole prompt.
+        # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full.
+        is_prefill = self._get_quantized_length(layer_idx) == 0
 
-            if (
-                current_key.dim() == 4
-                and current_key.shape[-2] >= self.config.residual_length
-                and current_key.shape[-2] > self._get_quantized_length(layer_idx)
-            ):
-                # Quantize the older part, keep recent tokens in original precision
-                split_idx = current_key.shape[-2] - self.config.residual_length
+        if is_prefill:
+            self._quantized_key_cache.append(self._quantize(key_tensors.contiguous(), axis=self.config.axis_key))
+            self._quantized_value_cache.append(self._quantize(value_tensors.contiguous(), axis=self.config.axis_value))
 
-                # Get the part to quantize
-                key_to_quantize = current_key[:, :, :split_idx, :].contiguous()
-                value_to_quantize = current_value[:, :, :split_idx, :].contiguous()
+            # Clear the residual cache
+            cache.key_cache[layer_idx] = torch.zeros(
+                0,
+                dtype=key_tensors.dtype,
+                device=key_tensors.device,
+            )
+            cache.value_cache[layer_idx] = torch.zeros(
+                0,
+                dtype=value_tensors.dtype,
+                device=value_tensors.device,
+            )
+            # On prefill, we return the original prompt
+            keys_to_return, values_to_return = key_tensors, value_tensors
 
+        else:
+            # Prepend the previously quantized cache
+            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
+            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
+            keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2)
+            values_to_return = torch.cat([dequant_value, value_tensors], dim=-2)
+            if key_tensors.shape[-2] >= self.config.residual_length:
                 # Quantize and store
-                self._quantized_key_cache[layer_idx] = self._quantize(key_to_quantize, axis=self.config.axis_key)
-                self._quantized_value_cache[layer_idx] = self._quantize(value_to_quantize, axis=self.config.axis_value)
-
-                # Keep only the recent tokens in original precision
-                cache.key_cache[layer_idx] = current_key[:, :, split_idx:, :]
-                cache.value_cache[layer_idx] = current_value[:, :, split_idx:, :]
-
-        # Return the full tensors for this update
-        if self._quantized_key_cache[layer_idx].numel() > 0:
-            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
-            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
-            full_key = torch.cat([dequant_key, cache.key_cache[layer_idx]], dim=-2)
-            full_value = torch.cat([dequant_value, cache.value_cache[layer_idx]], dim=-2)
-            return full_key, full_value
+                self._quantized_key_cache[layer_idx] = self._quantize(
+                    keys_to_return.contiguous(), axis=self.config.axis_key
+                )
+                self._quantized_value_cache[layer_idx] = self._quantize(
+                    values_to_return.contiguous(), axis=self.config.axis_value
+                )
 
-        return key_tensors, value_tensors
+                # Clear the residual cache
+                cache.key_cache[layer_idx] = torch.zeros(
+                    0,
+                    dtype=key_tensors.dtype,
+                    device=key_tensors.device,
+                )
+                cache.value_cache[layer_idx] = torch.zeros(
+                    0,
+                    dtype=value_tensors.dtype,
+                    device=value_tensors.device,
+                )
 
-    def _get_quantized_length(self, layer_idx: int) -> int:
-        """Get the length of quantized cache for a layer."""
-        if layer_idx < len(self._quantized_key_cache) and self._quantized_key_cache[layer_idx].numel() > 0:
-            # This would depend on the specific quantization implementation
-            return (
-                self._quantized_key_cache[layer_idx].shape[-2]
-                if hasattr(self._quantized_key_cache[layer_idx], "shape")
-                else 0
-            )
-        return 0
+        return keys_to_return, values_to_return
 
     def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor:
         """Quantize a tensor - to be implemented by specific quantization backends."""
@@ -2227,6 +2145,12 @@ def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor:
         """Dequantize tensor using quanto backend."""
         return qtensor.dequantize()
 
+    def _get_quantized_length(self, layer_idx: int) -> int:
+        """Get the length of quantized cache for a layer."""
+        if layer_idx < len(self._quantized_key_cache):
+            return self._quantized_key_cache[layer_idx].shape[-2]
+        return 0
+
 
 class HQQQuantizedCacheProcessor(QuantizedCacheProcessor):
     """
@@ -2277,6 +2201,12 @@ def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor:
         tensor = self.quantizer.dequantize(quant_tensor, meta)
         return tensor
 
+    def _get_quantized_length(self, layer_idx: int) -> int:
+        """Get the length of quantized cache for a layer."""
+        if layer_idx < len(self._quantized_key_cache):
+            return self._quantized_key_cache[layer_idx][0].shape[-2]
+        return 0
+
 
 class QuantizedCache(DynamicCache):
     """
@@ -2293,9 +2223,113 @@ class QuantizedCache(DynamicCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
+        if cache_config.backend == "quanto":
+            processor = QuantoQuantizedCacheProcessor(cache_config)
+        elif cache_config.backend == "hqq":
+            processor = HQQQuantizedCacheProcessor(cache_config)
+        else:
+            raise ValueError(f"Unknown quantization backend `{cache_config.backend}`")
+
+        processors = CacheProcessorList([processor])
         super().__init__(processors=processors)
 
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
+        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to verify attn_weight shape in some models
+        return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1
+
+
+class QuantoQuantizedCache(QuantizedCache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+
+    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
+
+    Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+
+    Example:
+
+        ```python
+        >>> # Run pip install quanto first if you don't have it yet
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
+
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_config = QuantizedCacheConfig(nbits=4)
+        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values # access cache filled with key/values from generation
+        QuantoQuantizedCache()
+        ```
+    """
+
+    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
+        Cache.__init__(self, processors=processors)
+
+
+class HQQQuantizedCache(QuantizedCache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+
+    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
+
+    Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+
+    Example:
+
+        ```python
+        >>> # Run pip install hqq first if you don't have it yet
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
+
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
+        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values # access cache filled with key/values from generation
+        HQQQuantizedCache()
+        ```
+    """
+
+    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)])
+        Cache.__init__(self, processors=processors)
+
 
 class SinkCache(Cache):
     """
