Cache: don't throw warnings on gemma2 when instantiating a new cache
gante authored Sep 19, 2024
1 parent b50ff59 commit 52920b5
Showing 4 changed files with 38 additions and 30 deletions.
10 changes: 9 additions & 1 deletion src/transformers/cache_utils.py
@@ -1660,7 +1660,15 @@ def get_max_length(self) -> Optional[int]:
         return self.max_cache_len

     def get_seq_length(self, layer_idx: Optional[int] = 0):
-        return None
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        if layer_idx != 0:
+            raise ValueError(
+                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
+                "Using the `layer_idx` argument is not supported."
+            )
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def reset(self):
         """Resets the cache values while preserving the objects"""
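The new `get_seq_length` infers how many tokens the cache holds from the cache tensors themselves. Below is a minimal, standalone sketch of that occupancy count using dummy tensors rather than a real `HybridCache`; the shapes are illustrative assumptions, not values taken from the commit.

import torch

# A slot along the sequence axis counts as "occupied" if any value in its head dimension
# is non-zero; only batch member 0 and head 0 are checked, as in the new implementation.
key_cache_layer0 = torch.zeros(2, 4, 8, 16)  # (batch, num_heads, max_cache_len, head_dim)
key_cache_layer0[:, :, :3, :] = 1.0  # pretend 3 tokens have been written to the cache

seq_length = key_cache_layer0[0, 0].any(dim=-1).sum()
print(seq_length)  # tensor(3)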
41 changes: 14 additions & 27 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -710,20 +710,13 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
             config.n_positions - 1]`.

             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+        past_key_values (`HybridCache`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-            Two formats are allowed:
-            - a [`~cache_utils.Cache`] instance, see our
-            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
-            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
-            cache format.
-            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
-            legacy cache format will be returned.
+            Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
+            cache classes.

             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@@ -789,7 +782,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -818,19 +811,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if cache_position is None:
-            if past_key_values is None:
-                cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
-            else:
-                raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
-
-        # Probably a forward call with caching, so we set up cache for one call only
-        if use_cache and past_key_values is None and not self.training:
-            logger.warning_once(
-                "You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ",
-                "If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
-                "will be created for this call. See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)",
-            )
+        # Instantiate an empty cache if needed.
+        if use_cache and past_key_values is None:
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
@@ -840,6 +822,11 @@ def forward(
                 dtype=inputs_embeds.dtype,
             )

+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
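The `cache_position` fallback added above starts the new positions right after the tokens already stored in the cache. A small sketch with plain tensors, where `past_seen_tokens` and `new_tokens` are hypothetical stand-ins for `past_key_values.get_seq_length()` and `inputs_embeds.shape[1]`:

import torch

past_seen_tokens = 0  # what get_seq_length() would report for an empty or fresh cache
new_tokens = 5        # number of tokens in the current forward pass
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)  # default when `position_ids` is not provided
print(cache_position.tolist(), position_ids.shape)  # [0, 1, 2, 3, 4] torch.Size([1, 5])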

@@ -912,7 +899,7 @@ def _update_causal_mask(
         attention_mask: torch.Tensor,
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
-        past_key_values: Cache,
+        past_key_values: HybridCache,
         output_attentions: bool,
     ):
         # Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
@@ -981,7 +968,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -1202,7 +1189,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
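A hedged usage sketch of the behaviour these Gemma 2 changes target: a forward pass with `use_cache=True` and no `past_key_values` should now quietly build an empty `HybridCache` instead of emitting a warning. The checkpoint name is an illustrative assumption, not something pinned by this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)  # no `past_key_values` passed, no warning expected

print(type(outputs.past_key_values).__name__)  # expected: HybridCache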
12 changes: 10 additions & 2 deletions src/transformers/models/mimi/modeling_mimi.py
@@ -1000,8 +1000,16 @@ def forward(
             )
             use_cache = False

-        if use_cache and past_key_values is None and not self.training:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_cache and not isinstance(past_key_values, Cache):
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
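The Mimi change follows the cache normalisation pattern used elsewhere in the library. Below is a standalone sketch of that branching as a plain helper function; the name `normalize_cache` is an assumption for illustration, not part of the model code.

from transformers.cache_utils import Cache, DynamicCache

def normalize_cache(past_key_values, use_cache=True):
    # A missing cache becomes a fresh DynamicCache; a legacy tuple-of-tuples cache is
    # converted and flagged as deprecated.
    if use_cache and not isinstance(past_key_values, Cache):
        if past_key_values is None:
            past_key_values = DynamicCache()
        else:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            print("deprecation: tuple-of-tuples caches are going away, pass a `Cache` instance")
    return past_key_values

print(type(normalize_cache(None)).__name__)  # DynamicCache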
5 changes: 5 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -86,10 +86,15 @@ def setUp(self):
     def test_model_outputs_equivalence(self, **kwargs):
         pass

+    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
     def test_eager_matches_sdpa_inference(self):
         pass

+    @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
+    def test_eager_matches_sdpa_generate(self):
+        pass
+
     @parameterized.expand([("random",), ("same",)])
     @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
     def test_assisted_decoding_matches_greedy_search(self, assistant_type):
