
Commit a2fe24c

fix quantized, add tests
1 parent 1c3cbcc commit a2fe24c

File tree

3 files changed: +221 -122 lines changed

src/transformers/cache_utils.py

Lines changed: 153 additions & 120 deletions
@@ -214,7 +214,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(K={self.key_cache}, V={self.value_cache})"
+        key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})"
+        value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})"
+        return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"
 
 
 class Cache:
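For reference, a minimal standalone sketch of what the new `__repr__` prints; `ToyLayer` below is a hypothetical stand-in for the cache layer class, not part of the library:

```python
# Hypothetical stand-in class: only the __repr__ body is taken from the hunk above.
import torch


class ToyLayer:
    def __init__(self, key_cache=None, value_cache=None):
        self.key_cache = key_cache
        self.value_cache = value_cache

    def __repr__(self):
        # Print tensor shapes instead of dumping the full tensor contents.
        key_repr = "None" if self.key_cache is None else f"t({tuple(self.key_cache.shape)})"
        value_repr = "None" if self.value_cache is None else f"t({tuple(self.value_cache.shape)})"
        return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"


print(ToyLayer())  # ToyLayer(K=None, V=None)
print(ToyLayer(torch.zeros(1, 2, 5, 64), torch.zeros(1, 2, 5, 64)))  # ToyLayer(K=t((1, 2, 5, 64)), V=t((1, 2, 5, 64)))
```

The point of the change: the repr no longer dumps entire key/value tensors, which keeps logs readable for long sequences.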
@@ -860,7 +862,7 @@ def update(
     def get_seq_length(self, cache_position: Optional[torch.LongTensor] = None) -> int:
         """Returns the sequence length of the cached states."""
         # TODO: deprecate this function in favor of `cache_position`
-        if self is None or self.key_cache is None:
+        if self is None or self.key_cache is None or self.key_cache.numel() == 0:
            return 0
         return self.key_cache.shape[-2]
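The extra `numel() == 0` guard matters because, in the `post_update` hunk further down, the residual cache is cleared to an empty tensor. A minimal illustrative snippet (standalone, not library code):

```python
import torch

key_cache = torch.zeros(0)        # residual cache after being cleared, as in the hunks below
print(key_cache.numel() == 0)     # True -> the new guard returns 0 early
try:
    print(key_cache.shape[-2])    # the old lookup on an empty 1-D tensor
except IndexError as err:
    print("old code path fails:", err)
```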

@@ -1017,94 +1019,6 @@ def __init__(self, config: Optional[CacheConfig] = None) -> None:
         super().__init__(processors=processors, config=config)
 
 
-class QuantoQuantizedCache(DynamicCache):
-    """
-    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
-    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
-
-    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
-    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
-    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
-
-    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
-    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
-    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
-
-    Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
-
-    Parameters:
-        cache_config (`QuantizedCacheConfig`):
-            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
-
-    Example:
-
-        ```python
-        >>> # Run pip install quanto first if you don't have it yet
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> cache_config = QuantizedCacheConfig(nbits=4)
-        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        QuantoQuantizedCache()
-        ```
-    """
-
-    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
-        super(DynamicCache, self).__init__(processors=processors)
-
-
-class HQQQuantizedCache(DynamicCache):
-    """
-    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
-    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
-
-    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
-    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
-    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
-
-    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
-    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
-    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
-
-    Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
-
-    Parameters:
-        cache_config (`QuantizedCacheConfig`):
-            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
-
-    Example:
-
-        ```python
-        >>> # Run pip install hqq first if you don't have it yet
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
-        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        HQQQuantizedCache()
-        ```
-    """
-
-    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)])
-        super(DynamicCache, self).__init__(processors=processors)
-
-
 class StaticLayer(CacheLayer):
     is_compileable = True
 
@@ -2125,40 +2039,55 @@ def post_update(
             self._quantized_key_cache.append(torch.empty(0))
             self._quantized_value_cache.append(torch.empty(0))
 
-        # Check if we need to quantize
-        if layer_idx < len(cache.key_cache):
-            current_key = cache.key_cache[layer_idx]
-            current_value = cache.value_cache[layer_idx]
+        # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer
+        # On the first forward pass, we quantize the whole prompt.
+        # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full.
+        is_prefill = self._get_quantized_length(layer_idx) == 0
 
-            if (
-                current_key.dim() == 4
-                and current_key.shape[-2] >= self.config.residual_length
-                and current_key.shape[-2] > self._get_quantized_length(layer_idx)
-            ):
-                # Quantize the older part, keep recent tokens in original precision
-                split_idx = current_key.shape[-2] - self.config.residual_length
+        if is_prefill:
+            self._quantized_key_cache[layer_idx] = self._quantize(key_tensors.contiguous(), axis=self.config.axis_key)
+            self._quantized_value_cache[layer_idx] = self._quantize(
+                value_tensors.contiguous(), axis=self.config.axis_value
+            )
 
-                # Get the part to quantize
-                key_to_quantize = current_key[:, :, :split_idx, :].contiguous()
-                value_to_quantize = current_value[:, :, :split_idx, :].contiguous()
+            # Clear the residual cache
+            cache.key_cache[layer_idx] = torch.zeros(
+                0,
+                dtype=key_tensors.dtype,
+                device=key_tensors.device,
+            )
+            cache.value_cache[layer_idx] = torch.zeros(
+                0,
+                dtype=value_tensors.dtype,
+                device=value_tensors.device,
+            )
+            # On prefill, we return the original prompt
+            keys_to_return, values_to_return = key_tensors, value_tensors
 
+        else:
+            # Prepend the previously quantized cache
+            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
+            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
+            keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2)
+            values_to_return = torch.cat([dequant_value, value_tensors], dim=-2)
+            if key_tensors.shape[-2] >= self.config.residual_length:
                 # Quantize and store
-                self._quantized_key_cache[layer_idx] = self._quantize(key_to_quantize, axis=self.config.axis_key)
-                self._quantized_value_cache[layer_idx] = self._quantize(value_to_quantize, axis=self.config.axis_value)
-
-                # Keep only the recent tokens in original precision
-                cache.key_cache[layer_idx] = current_key[:, :, split_idx:, :]
-                cache.value_cache[layer_idx] = current_value[:, :, split_idx:, :]
+                self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.config.axis_key)
+                self._quantized_value_cache[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.config.axis_value)
+
+                # Clear the residual cache
+                cache.key_cache[layer_idx] = torch.zeros(
+                    0,
+                    dtype=key_tensors.dtype,
+                    device=key_tensors.device,
+                )
+                cache.value_cache[layer_idx] = torch.zeros(
+                    0,
+                    dtype=value_tensors.dtype,
+                    device=value_tensors.device,
+                )
 
-        # Return the full tensors for this update
-        if self._quantized_key_cache[layer_idx].numel() > 0:
-            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
-            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
-            full_key = torch.cat([dequant_key, cache.key_cache[layer_idx]], dim=-2)
-            full_value = torch.cat([dequant_value, cache.value_cache[layer_idx]], dim=-2)
-            return full_key, full_value
-
-        return key_tensors, value_tensors
+        return keys_to_return, values_to_return
 
     def _get_quantized_length(self, layer_idx: int) -> int:
         """Get the length of quantized cache for a layer."""
@@ -2293,9 +2222,113 @@ class QuantizedCache(DynamicCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
+        if cache_config.backend == "quanto":
+            processor = QuantoQuantizedCacheProcessor(cache_config)
+        elif cache_config.backend == "hqq":
+            processor = HQQQuantizedCacheProcessor(cache_config)
+        else:
+            raise ValueError(f"Unknown quantization backend `{cache_config.backend}`")
+
+        processors = CacheProcessorList([processor])
         super().__init__(processors=processors)
 
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
+        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to verify attn_weight shape in some models
+        return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1
+
+
+class QuantoQuantizedCache(QuantizedCache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+
+    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
+
+    Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+
+    Example:
+
+        ```python
+        >>> # Run pip install quanto first if you don't have it yet
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
+
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_config = QuantizedCacheConfig(nbits=4)
+        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values # access cache filled with key/values from generation
+        QuantoQuantizedCache()
+        ```
+    """
+
+    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
+        Cache.__init__(self, processors=processors)
+
+
+class HQQQuantizedCache(QuantizedCache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+
+    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
+
+    Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+
+    Example:
+
+        ```python
+        >>> # Run pip install hqq first if you don't have it yet
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
+
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
+        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values # access cache filled with key/values from generation
+        HQQQuantizedCache()
+        ```
+    """
+
+    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)])
+        Cache.__init__(self, processors=processors)
+
 
 class SinkCache(Cache):
     """

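For orientation, a simplified standalone sketch of the control flow introduced in the `post_update` hunk above. The `fake_quantize`/`fake_dequantize` helpers and the dict-based state are hypothetical stand-ins for the quanto/HQQ processors; only the branching mirrors the diff:

```python
import torch


def fake_quantize(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the quanto/HQQ backends: real code returns a packed low-bit representation.
    return x.clone()


def fake_dequantize(q: torch.Tensor) -> torch.Tensor:
    return q


def post_update(state: dict, new_keys: torch.Tensor, residual_length: int) -> torch.Tensor:
    """Quantize the whole prompt on prefill, then accumulate decode tokens in the
    residual cache and fold them into the quantized cache once it is full."""
    residual = torch.cat([state["residual"], new_keys], dim=-2)
    if state["quantized"].numel() == 0:
        # Prefill: quantize everything, clear the residual, return the prompt as-is.
        state["quantized"] = fake_quantize(residual.contiguous())
        state["residual"] = residual[..., :0, :]
        return residual
    # Decode: prepend the dequantized past to the residual tokens.
    full = torch.cat([fake_dequantize(state["quantized"]), residual], dim=-2)
    if residual.shape[-2] >= residual_length:
        # Residual cache is full: re-quantize the whole sequence and clear the residual.
        state["quantized"] = fake_quantize(full.contiguous())
        state["residual"] = residual[..., :0, :]
    else:
        state["residual"] = residual
    return full


state = {"quantized": torch.empty(0), "residual": torch.zeros(1, 2, 0, 4)}
post_update(state, torch.randn(1, 2, 5, 4), residual_length=4)                 # prefill step
print(post_update(state, torch.randn(1, 2, 1, 4), residual_length=4).shape)    # torch.Size([1, 2, 6, 4])
```

With the dispatch added to `QuantizedCache.__init__` in the last hunk, a `QuantizedCacheConfig` with `backend="hqq"` passed to `QuantizedCache` would reach `HQQQuantizedCacheProcessor`; treating `backend` as a config field is an assumption read off the `elif` chain above, not something this diff states elsewhere.
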
src/transformers/generation/utils.py

Lines changed: 6 additions & 1 deletion
@@ -574,7 +574,12 @@ def prepare_inputs_for_generation(
         # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
         # (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
         elif cache_position is None:
-            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+            past_length = 0
+            if past_key_values is not None:
+                if not isinstance(past_key_values, Cache):
+                    past_length = past_key_values[0][0].shape[2]
+                elif hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() is not None:
+                    past_length = past_key_values.get_seq_length()
             cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
 
         # 2. Generic cache-dependent input preparation
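Illustrative only: how the rewritten branch derives `past_length` for a legacy tuple cache versus a `Cache` object, and how `cache_position` is then built. The `legacy_cache` layout and the `SimpleCache` class are hypothetical stand-ins for the two branches:

```python
import torch

input_ids = torch.ones(1, 8, dtype=torch.long)

# Legacy format: a tuple of (key, value) pairs per layer -> length comes from key.shape[2]
legacy_cache = ((torch.zeros(1, 2, 5, 64), torch.zeros(1, 2, 5, 64)),)
past_length = legacy_cache[0][0].shape[2]  # 5


# Cache objects report their own length via get_seq_length()
class SimpleCache:
    def get_seq_length(self):
        return 5


past_length = SimpleCache().get_seq_length()  # 5

# Either way, cache_position covers the not-yet-processed positions of the input
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long)
print(cache_position)  # tensor([5, 6, 7])
```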
