Merged
Changes from all commits
193 commits
e083d5c
start
Cyrilvallez Apr 29, 2025
e1d43c4
start having a clean 4d mask primitive
Cyrilvallez Apr 29, 2025
59a69c4
Update mask_utils.py
Cyrilvallez Apr 29, 2025
8aa61b0
Update mask_utils.py
Cyrilvallez Apr 29, 2025
ee7bafd
switch name
Cyrilvallez Apr 29, 2025
bfcc5d8
Update masking_utils.py
Cyrilvallez Apr 29, 2025
f92757a
add a new AttentionMask tensor class
ArthurZucker Apr 30, 2025
932c17b
fix import
ArthurZucker Apr 30, 2025
1227356
nits
ArthurZucker Apr 30, 2025
542d054
fixes
ArthurZucker Apr 30, 2025
99235bb
use full and quandrants
ArthurZucker Apr 30, 2025
6d16c6b
general sdpa mask for all caches
Cyrilvallez Apr 30, 2025
c98bc68
style
Cyrilvallez Apr 30, 2025
f3027fe
start some tests
Cyrilvallez Apr 30, 2025
7eeed31
tests with sliding, chunked
Cyrilvallez Apr 30, 2025
ddd6059
add styling
ArthurZucker May 1, 2025
9397d17
test hybrid
Cyrilvallez Apr 30, 2025
bb6ea15
Update masking_utils.py
Cyrilvallez Apr 30, 2025
d7c4fa7
small temp fixes
Cyrilvallez Apr 30, 2025
3ea388c
Update modeling_gemma2.py
Cyrilvallez Apr 30, 2025
2165232
compile compatible
Cyrilvallez May 5, 2025
07fda4f
Update masking_utils.py
Cyrilvallez May 5, 2025
eed6383
improve
Cyrilvallez May 5, 2025
15485c3
start making it more general
Cyrilvallez May 5, 2025
ce4080b
Update masking_utils.py
Cyrilvallez May 5, 2025
039e444
generate
Cyrilvallez May 6, 2025
8b99dde
make it work with flex style primitives!
Cyrilvallez May 6, 2025
14f0163
Update masking_utils.py
Cyrilvallez May 6, 2025
6bb8742
Update masking_utils.py
Cyrilvallez May 7, 2025
1c16acb
Update masking_utils.py
Cyrilvallez May 7, 2025
8f75abb
improve
Cyrilvallez May 7, 2025
387cdb4
Update cache_utils.py
Cyrilvallez May 7, 2025
346bfa9
Update masking_utils.py
Cyrilvallez May 7, 2025
b826d91
simplify - starting to look good!
Cyrilvallez May 7, 2025
83f20d3
Update masking_utils.py
Cyrilvallez May 7, 2025
6fa1d35
name
Cyrilvallez May 7, 2025
05777fb
Update masking_utils.py
Cyrilvallez May 7, 2025
6ab437e
style
Cyrilvallez May 7, 2025
1d6c900
Update masking_utils.py
Cyrilvallez May 7, 2025
b5e1ebd
Update masking_utils.py
Cyrilvallez May 7, 2025
26a428c
Update masking_utils.py
Cyrilvallez May 7, 2025
71162e4
Update masking_utils.py
Cyrilvallez May 7, 2025
28d5a19
small fix for flex
Cyrilvallez May 7, 2025
8532033
flex compile
Cyrilvallez May 7, 2025
f3c8e7c
FA2
Cyrilvallez May 8, 2025
d0c0b40
Update masking_utils.py
Cyrilvallez May 8, 2025
bbc6bec
Escape for TGI/vLLM!
Cyrilvallez May 8, 2025
ada1627
Update masking_utils.py
Cyrilvallez May 8, 2025
1620f1e
Update masking_utils.py
Cyrilvallez May 8, 2025
8f3a2a0
Update masking_utils.py
Cyrilvallez May 8, 2025
d5b3285
General case without cache
Cyrilvallez May 8, 2025
ce22728
rename
Cyrilvallez May 8, 2025
7bf1352
full test on llama4
Cyrilvallez May 8, 2025
05529ff
small fix for FA2 guard with chunk
Cyrilvallez May 8, 2025
5afd898
Update modeling_gemma2.py
Cyrilvallez May 8, 2025
67bea3f
post rebase cleanup
Cyrilvallez May 9, 2025
7bb501d
FA2 supports static cache!
Cyrilvallez May 9, 2025
9bb6219
Update modeling_flash_attention_utils.py
Cyrilvallez May 9, 2025
f4849ab
Update flex_attention.py
Cyrilvallez May 12, 2025
44f9b65
Update masking_utils.py
Cyrilvallez May 12, 2025
dc52eb3
Update masking_utils.py
Cyrilvallez May 12, 2025
d2f645d
Update utils.py
Cyrilvallez May 12, 2025
07bd06c
override for export
Cyrilvallez May 12, 2025
f6735f2
Update executorch.py
Cyrilvallez May 12, 2025
f2d8a54
Update executorch.py
Cyrilvallez May 12, 2025
ee0afdd
Update executorch.py
Cyrilvallez May 12, 2025
73549e5
Update executorch.py
Cyrilvallez May 12, 2025
7031bc4
Update masking_utils.py
Cyrilvallez May 12, 2025
552b586
Update masking_utils.py
Cyrilvallez May 12, 2025
d2fb4de
output attentions
Cyrilvallez May 12, 2025
628dcd8
style
Cyrilvallez May 12, 2025
27fb93f
Update masking_utils.py
Cyrilvallez May 12, 2025
a091517
Update executorch.py
Cyrilvallez May 12, 2025
cbf1144
Add doicstring
Cyrilvallez May 12, 2025
59eb3cc
Add license and put mask visualizer at the end
Cyrilvallez May 12, 2025
85ab5da
Update test_modeling_common.py
Cyrilvallez May 12, 2025
daf5bee
fix broken test
Cyrilvallez May 12, 2025
201da65
Update test_modeling_gemma.py
Cyrilvallez May 12, 2025
f06c7cd
Update test_modeling_gemma2.py
Cyrilvallez May 12, 2025
73a12b4
Use fullgraph=False with FA2
Cyrilvallez May 13, 2025
d0f6f7f
Update utils.py
Cyrilvallez May 13, 2025
5a25046
change name
Cyrilvallez May 13, 2025
cd9461d
Update masking_utils.py
Cyrilvallez May 13, 2025
4169bc3
improve doc
Cyrilvallez May 13, 2025
3166d47
change name
Cyrilvallez May 13, 2025
528aab6
Update modeling_attn_mask_utils.py
Cyrilvallez May 13, 2025
77f2c66
more explicit logic based on model's property
Cyrilvallez May 13, 2025
58d1384
pattern in config
Cyrilvallez May 13, 2025
7a6ac01
extend
Cyrilvallez May 13, 2025
f390675
fixes
Cyrilvallez May 13, 2025
e26eb84
make it better
Cyrilvallez May 13, 2025
bca422c
generalize to other test models
Cyrilvallez May 13, 2025
a66697e
fix
Cyrilvallez May 13, 2025
7ea8db7
Update masking_utils.py
Cyrilvallez May 13, 2025
1d3751f
fix
Cyrilvallez May 13, 2025
df43917
do not check mask equivalence if layer types are different
Cyrilvallez May 14, 2025
095746a
executorch
Cyrilvallez May 14, 2025
770422c
Update modeling_gemma2.py
Cyrilvallez May 14, 2025
0b5a817
Update masking_utils.py
Cyrilvallez May 14, 2025
cf5212c
use layer_idx instead
Cyrilvallez May 14, 2025
e28d663
adjust
Cyrilvallez May 14, 2025
53e9f47
Update masking_utils.py
Cyrilvallez May 14, 2025
8e2bdd1
test
Cyrilvallez May 14, 2025
558c47e
fix imports
Cyrilvallez May 14, 2025
df49780
Update modeling_gemma2.py
Cyrilvallez May 14, 2025
a87f7dd
other test models
Cyrilvallez May 14, 2025
8426b34
Update modeling_llama4.py
Cyrilvallez May 14, 2025
413d446
Update masking_utils.py
Cyrilvallez May 14, 2025
7f0f989
improve
Cyrilvallez May 15, 2025
3ed17a2
simplify
Cyrilvallez May 15, 2025
f23236d
Update masking_utils.py
Cyrilvallez May 15, 2025
0ffff1d
typos
Cyrilvallez May 15, 2025
09d32df
typo
Cyrilvallez May 15, 2025
e20ebab
fix
Cyrilvallez May 15, 2025
d273325
Update masking_utils.py
Cyrilvallez May 15, 2025
5ae049c
default DynamicCache
Cyrilvallez May 15, 2025
326bacf
remove default cache
Cyrilvallez May 15, 2025
d58eaab
simplify
Cyrilvallez May 15, 2025
02a9180
Update masking_utils.py
Cyrilvallez May 15, 2025
d67de19
Update masking_utils.py
Cyrilvallez May 15, 2025
3831ccc
Update masking_utils.py
Cyrilvallez May 15, 2025
4b54f18
Update masking_utils.py
Cyrilvallez May 15, 2025
6edf116
simplify
Cyrilvallez May 15, 2025
18614a5
Update masking_utils.py
Cyrilvallez May 15, 2025
bd931a0
Update masking_utils.py
Cyrilvallez May 15, 2025
93f8d82
Update masking_utils.py
Cyrilvallez May 15, 2025
711ab9b
export
Cyrilvallez May 15, 2025
58f198e
Update executorch.py
Cyrilvallez May 15, 2025
9c69ae5
Update executorch.py
Cyrilvallez May 15, 2025
4e40516
Update flex_attention.py
Cyrilvallez May 15, 2025
6a28a34
Update executorch.py
Cyrilvallez May 15, 2025
c70bf3c
upstream to modular gemma 1 & 2
Cyrilvallez May 15, 2025
3a972d4
Update modular_mistral.py
Cyrilvallez May 15, 2025
7ca132d
switch names
Cyrilvallez May 15, 2025
34a55c5
use dict
Cyrilvallez May 15, 2025
5c89d72
put it in the Layer directly
Cyrilvallez May 15, 2025
e6891b6
update copy model source for mask functions
Cyrilvallez May 15, 2025
ac02170
apply so many modular (hopefully 1 shot)
Cyrilvallez May 15, 2025
59e11ab
use explicite dicts for make style happy
Cyrilvallez May 15, 2025
27041e0
protect import
Cyrilvallez May 15, 2025
0cf18e2
check docstring
Cyrilvallez May 15, 2025
47158df
better default in hybrid caches
Cyrilvallez May 15, 2025
022c4a9
qwens
Cyrilvallez May 16, 2025
94896dc
Update modular_qwen2.py
Cyrilvallez May 16, 2025
9bbe1cb
simplify core logic!
Cyrilvallez May 16, 2025
0844a49
Update executorch.py
Cyrilvallez May 16, 2025
dbbecde
qwen3 moe
Cyrilvallez May 16, 2025
a350263
Update masking_utils.py
Cyrilvallez May 16, 2025
09b0148
Update masking_utils.py
Cyrilvallez May 16, 2025
fcd21a4
simplify a lot sdpa causal skip
Cyrilvallez May 16, 2025
8cb637f
Update masking_utils.py
Cyrilvallez May 16, 2025
481f086
post-rebase
Cyrilvallez May 16, 2025
91c87f8
gemma3 finally
Cyrilvallez May 19, 2025
9bda864
style
Cyrilvallez May 19, 2025
d24309f
check it before
Cyrilvallez May 19, 2025
8e153a1
gemma3
Cyrilvallez May 19, 2025
ebc7f9d
More general with newer torch
Cyrilvallez May 20, 2025
31008ba
align gemma3
Cyrilvallez May 20, 2025
3c385ea
Update utils.py
Cyrilvallez May 20, 2025
b206cd5
Update utils.py
Cyrilvallez May 20, 2025
b0850bf
Update masking_utils.py
Cyrilvallez May 20, 2025
79eac77
Update test_modeling_common.py
Cyrilvallez May 20, 2025
29a6bc2
Update flex_attention.py
Cyrilvallez May 20, 2025
bb2dda0
Update flex_attention.py
Cyrilvallez May 20, 2025
1b85bbb
Update flex_attention.py
Cyrilvallez May 20, 2025
f76df19
test
Cyrilvallez May 20, 2025
3ff3908
executorch
Cyrilvallez May 20, 2025
fd8a6a2
Update test_modeling_common.py
Cyrilvallez May 21, 2025
84db8ee
Update masking_utils.py
Cyrilvallez May 21, 2025
83ba79f
Update masking_utils.py
Cyrilvallez May 21, 2025
acbe4be
Update masking_utils.py
Cyrilvallez May 21, 2025
b0333de
Update masking_utils.py
Cyrilvallez May 21, 2025
3c48334
Update executorch.py
Cyrilvallez May 21, 2025
cfd0694
Update test_modeling_common.py
Cyrilvallez May 21, 2025
0181042
fix copies
Cyrilvallez May 21, 2025
ad5fb36
device
Cyrilvallez May 21, 2025
b477c1e
sdpa can be used without mask -> pass the torchscript tests in this case
Cyrilvallez May 21, 2025
3b71b7b
Use enum for check
Cyrilvallez May 21, 2025
1a05ca1
revert enum and add check instead
Cyrilvallez May 21, 2025
2029cfa
remove broken test
Cyrilvallez May 21, 2025
28d62da
cohere2
Cyrilvallez May 21, 2025
9d7bd3a
some doc & reorganize the Interface
Cyrilvallez May 21, 2025
343ab95
Update tensor_parallel.py
Cyrilvallez May 21, 2025
78a21ea
Update tensor_parallel.py
Cyrilvallez May 21, 2025
4c87caa
doc and dummy
Cyrilvallez May 21, 2025
1f21213
Update test_modeling_paligemma2.py
Cyrilvallez May 21, 2025
e353067
Update modeling_falcon_h1.py
Cyrilvallez May 21, 2025
7979ac6
Update masking_utils.py
Cyrilvallez May 21, 2025
ba6501c
executorch patch
Cyrilvallez May 21, 2025
269969e
style
Cyrilvallez May 21, 2025
75ccf7a
CIs
Cyrilvallez May 21, 2025
7bcd55f
use register in executorch
Cyrilvallez May 22, 2025
9245fcd
final comments!
Cyrilvallez May 22, 2025
42 changes: 41 additions & 1 deletion docs/source/en/attention_interface.md
@@ -125,4 +125,44 @@ would expect from a usual Python dictionary:

# You can also globally `register` a new function directly on it
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
```

## Attention Mask Interface

Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
the `AttentionInterface`:

```python
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask
import torch

def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

Comment on lines +141 to +144

Collaborator: let's rather show how to do something like the paligemma or document masking here, something relevant!

Member Author: Those are a bit different, it's modifying the mask pattern vs adding a new mask format for the attention itself (both are complementary)

AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
```
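
As the review discussion above hints, patterns like the paligemma prefix masking or document masking for packed sequences live in the `mask_function` rather than in a new mask format. Here is a minimal sketch of such a `mask_mod`-style callable, purely for illustration (the `document_ids` tensor is an assumption of the example, not an API of this PR):

```python
import torch

# Toy packing metadata: the document id of every token position, shape (batch_size, seq_len).
# This tensor is invented for the example.
document_ids = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])

def causal_document_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # Causal rule: only attend to the current or earlier positions...
    causal = kv_idx <= q_idx
    # ...and only within the same packed document.
    same_document = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
    return causal & same_document
```

Such a callable would be passed as the `mask_function` argument described below; registering a whole new mask format (as done above) is only needed when the attention implementation expects a different mask object altogether.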

You have to register it because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a `BlockMask` format, while sdpa uses a 4D tensor).
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
and `attention_mask=None` will be passed along to the Attention layers.

The default signature of the attention mask functions is the following:

```python
def custom_attention_mask(
    batch_size: int,  # required arg
    cache_position: torch.Tensor,  # required arg
    kv_length: int,  # required arg
    kv_offset: int = 0,  # required arg
    mask_function: Callable = causal_mask_function,  # required arg
    attention_mask: Optional[torch.Tensor] = None,  # required arg
    **kwargs,  # a few additional args may be passed as kwargs, especially the model's config is always passed
) -> Optional[torch.Tensor]:
```

It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
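
To make that concrete, here is a hedged sketch (not part of the diff) of a custom mask function that follows the default signature above and simply swaps the `mask_function` before delegating to the stock `sdpa_mask`; the window size and the `my_*`/`sliding_window_*` names are invented for the example, while `sdpa_mask`, `causal_mask_function` and `AttentionMaskInterface` mirror the documentation above:

```python
from typing import Callable, Optional

import torch

from transformers import AttentionMaskInterface
from transformers.masking_utils import causal_mask_function, sdpa_mask

SLIDING_WINDOW = 4096  # illustrative window size

def sliding_window_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # Causal, but restricted to the last SLIDING_WINDOW positions.
    return (kv_idx <= q_idx) & (q_idx - kv_idx < SLIDING_WINDOW)

def my_sliding_sdpa_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Optional[torch.Tensor]:
    # Build the 4D mask with the sliding rule instead of the incoming mask_function,
    # reusing sdpa_mask (assumed to follow the same default signature).
    return sdpa_mask(
        batch_size,
        cache_position,
        kv_length,
        kv_offset=kv_offset,
        mask_function=sliding_window_mask_function,
        attention_mask=attention_mask,
        **kwargs,
    )

AttentionMaskInterface.register("my_sliding_sdpa_mask", my_sliding_sdpa_mask)
```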

If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
5 changes: 5 additions & 0 deletions docs/source/en/internal/modeling_utils.md
@@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
[[autodoc]] AttentionInterface
- register

## Attention Mask Functions

[[autodoc]] AttentionMaskInterface
- register

## Rotary Position Embedding Functions

[[autodoc]] dynamic_rope_update
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -445,6 +445,7 @@
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
_import_structure["masking_utils"] = ["AttentionMaskInterface"]
_import_structure["optimization"] = [
"Adafactor",
"get_constant_schedule",
@@ -914,6 +915,7 @@
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)
from .masking_utils import AttentionMaskInterface
from .model_debugging_utils import (
model_addition_debugger_context,
)
119 changes: 103 additions & 16 deletions src/transformers/cache_utils.py
@@ -196,6 +196,18 @@ def seen_tokens(self):
else:
return None

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
query_length = cache_position.shape[0]
past_seen_tokens = self.get_seq_length()
kv_length = query_length + past_seen_tokens
return kv_length, 0


@dataclass
class CacheConfig:
@@ -1084,8 +1096,6 @@ class SinkCache(Cache):
```
"""

is_sliding = True

def __init__(self, window_length: int, num_sink_tokens: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
@@ -1390,6 +1400,16 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
kv_length = self.get_max_cache_shape()
return kv_length, 0


class SlidingWindowCache(StaticCache):
"""
@@ -1446,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
```
"""

is_sliding = True
is_compileable = True

def __init__(
@@ -1465,6 +1484,7 @@ def __init__(
"config and it's not set to None."
)
max_cache_len = min(config.sliding_window, max_cache_len)
self.sliding_window = config.sliding_window
super().__init__(
config=config,
max_batch_size=max_batch_size,
@@ -1509,6 +1529,21 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
# torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
kv_length = max(query_length, self.get_max_cache_shape())
return kv_length, kv_offset


class EncoderDecoderCache(Cache):
"""
@@ -1761,12 +1796,17 @@ def __init__(
else config.num_key_value_heads
)

layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
else:
self.is_sliding = [False] * config.num_hidden_layers

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
self.sliding_window = min(config.sliding_window, max_cache_len)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
@@ -1775,7 +1815,7 @@ def __init__(
layer_device = device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1796,7 +1836,7 @@ def update(
if cache_position is None:
raise ValueError("`cache_position` must be provided for HybridCache.")

is_sliding_layer = self.is_sliding_list[layer_idx]
is_sliding_layer = self.is_sliding[layer_idx]

# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1843,6 +1883,26 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
if self.is_sliding[layer_idx]:
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]

local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
local_mask_kv_length = max(query_length, self.sliding_window)
return local_mask_kv_length, local_mask_kv_offset

full_mask_kv_offset = 0
full_mask_kv_length = self.get_max_cache_shape()
return full_mask_kv_length, full_mask_kv_offset


class HybridChunkedCache(Cache):
"""
@@ -1912,11 +1972,11 @@ def __init__(
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self._dtype = dtype

if hasattr(config.get_text_config(), "no_rope_layers"):
self.is_sliding = config.no_rope_layers
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
else:
layer_switch = getattr(config, "sliding_window_pattern", 2)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.is_sliding = [False] * config.num_hidden_layers

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@@ -1999,11 +2059,7 @@ def update(
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)

if self.is_sliding[layer_idx]:
update_fn = self._sliding_update
else:
update_fn = self._static_update

update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
return update_fn(
cache_position,
layer_idx,
@@ -2038,6 +2094,37 @@ def reset(self):
self.value_cache[layer_idx].zero_()
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
if self.is_sliding[layer_idx]:
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]

local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is the true general case for any Cache using local attention (sliding or chunked)
if first_cache_position >= self.sliding_window:
# Here the Cache is already full
local_mask_kv_length = self.sliding_window + query_length - 1
elif (
first_cache_position < self.sliding_window
and first_cache_position + query_length > self.sliding_window
):
# Here the Cache becomes full with the new input
local_mask_kv_length = first_cache_position + query_length
else:
# Here the Cache is still smaller than the local size, but we return the local size as it's static
local_mask_kv_length = self.sliding_window
return local_mask_kv_length, local_mask_kv_offset

full_mask_kv_offset = 0
full_mask_kv_length = self.get_max_cache_shape()
return full_mask_kv_length, full_mask_kv_offset


class OffloadedHybridCache(HybridChunkedCache):
def __init__(
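To help review the new `(kv_length, kv_offset)` contract, here is a small self-contained sketch (not part of the diff) that mirrors the sliding branch of `get_mask_sizes` added above and prints the sizes for a few illustrative situations; the window of 8 and the helper name are made up for the example:

```python
import torch

def sliding_mask_sizes(cache_position: torch.Tensor, sliding_window: int):
    # Mirrors the sliding branch of HybridChunkedCache.get_mask_sizes, outside of any Cache class.
    query_length = cache_position.shape[0]
    first_cache_position = cache_position[0]
    kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
    if first_cache_position >= sliding_window:
        # Cache already full: a fixed-size window slides along with the query
        kv_length = sliding_window + query_length - 1
    elif first_cache_position + query_length > sliding_window:
        # Cache becomes full with this input
        kv_length = first_cache_position + query_length
    else:
        # Cache still smaller than the local size: return the static local size
        kv_length = sliding_window
    return int(kv_length), int(kv_offset)

window = 8  # illustrative sliding window
print(sliding_mask_sizes(torch.arange(0, 5), window))    # prefill from an empty cache -> (8, 0)
print(sliding_mask_sizes(torch.arange(20, 21), window))  # decoding a single token     -> (8, 13)
print(sliding_mask_sizes(torch.arange(6, 11), window))   # prefill crossing the window -> (11, 0)
```

The three calls correspond to the three commented cases in the diff: cache already full, cache becoming full with the current input, and cache still smaller than the local size.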
13 changes: 13 additions & 0 deletions src/transformers/configuration_utils.py
@@ -1209,3 +1209,16 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
object="config", object_class="AutoConfig", object_files="configuration file"
)


ALLOWED_LAYER_TYPES = (
"full_attention",
"sliding_attention",
"chunked_attention",
)


def layer_type_validation(layer_types: list[str]):
"""Check that each entry in `layer_types` are allowed."""
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
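
A quick sketch (not part of the diff) of how the new validator behaves; the patterns are illustrative and not tied to a specific model:

```python
from transformers.configuration_utils import layer_type_validation

# A hybrid sliding/full pattern such as the ones this PR targets: passes silently.
layer_type_validation(["sliding_attention", "full_attention", "sliding_attention", "full_attention"])

# Anything outside ALLOWED_LAYER_TYPES raises.
try:
    layer_type_validation(["global_attention"])
except ValueError as err:
    print(err)
```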
35 changes: 30 additions & 5 deletions src/transformers/generation/utils.py
@@ -46,6 +46,7 @@
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..masking_utils import create_masks_for_generate
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
@@ -74,6 +75,7 @@
from .configuration_utils import (
NEED_SETUP_CACHE_CLASSES_MAPPING,
QUANT_BACKEND_CLASSES_MAPPING,
CompileConfig,
GenerationConfig,
GenerationMode,
)
@@ -649,12 +651,22 @@ def prepare_inputs_for_generation(
causal_mask_creation_function = getattr(
decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
)

# If it's not defined, it means the model uses the new general mask API
if causal_mask_creation_function is None: # can't be found
logger.warning_once(
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
"writing code, see Llama for an example implementation. If you're a user, please report this "
"issue on GitHub."
output_attentions = kwargs.get("output_attentions", False)
token_type_ids = getattr(model_input, "token_type_ids", None)
# Some models may overwrite the general one
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
attention_mask = causal_mask_creation_function(
config=self.config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
token_type_ids=token_type_ids,
)
else:
attention_mask = causal_mask_creation_function(
@@ -3539,6 +3551,19 @@ def _sample(
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2" and getattr(
model_kwargs.get("past_key_values"), "is_compileable", False
):
if generation_config.compile_config is None:
generation_config.compile_config = CompileConfig(fullgraph=False)
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
elif generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)

if generation_config.prefill_chunk_size is not None:
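As a usage note for the FA2 + static cache guard above, this is a hedged sketch of how a caller could opt out of full-graph compilation up front and avoid the warning (assuming `CompileConfig` is importable from `transformers.generation.configuration_utils`, as the import added in this hunk suggests):

```python
from transformers import GenerationConfig
from transformers.generation.configuration_utils import CompileConfig

# With flash_attention_2 and a compileable (static) cache, pass fullgraph=False explicitly
# so generate() does not have to override it and warn.
generation_config = GenerationConfig(
    cache_implementation="static",
    compile_config=CompileConfig(fullgraph=False),
)
```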