Skip to content

Commit f834d36

Browse files
authored
[gemma3] fix bidirectional attention mask (#38080)
* fix attn mask * attn viz doesn't show yellow cubes between images * bucketize made it hard with different number of crops * fixup
1 parent 2edb0e4 commit f834d36

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,10 +1062,21 @@ def _update_causal_mask(
10621062
if token_type_ids is not None and sequence_length != 1:
10631063
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
10641064
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
1065-
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
1065+
1066+
# Find where a new image block starts: 1 if image and previous not image
1067+
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
1068+
is_image = token_type_ids == 1
1069+
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
1070+
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
1071+
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
1072+
1073+
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
1074+
same_image_mask[image_group_ids == -1] = False # remove non-image
1075+
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
1076+
10661077
causal_mask = causal_mask.clone()
10671078
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
1068-
token_type_mask, 0.0
1079+
image_mask, 0.0
10691080
)
10701081

10711082
if attention_mask is not None:

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,10 +781,21 @@ def _update_causal_mask(
781781
if token_type_ids is not None and sequence_length != 1:
782782
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
783783
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
784-
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
784+
785+
# Find where a new image block starts: 1 if image and previous not image
786+
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
787+
is_image = token_type_ids == 1
788+
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
789+
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
790+
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
791+
792+
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
793+
same_image_mask[image_group_ids == -1] = False # remove non-image
794+
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
795+
785796
causal_mask = causal_mask.clone()
786797
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
787-
token_type_mask, 0.0
798+
image_mask, 0.0
788799
)
789800

790801
if attention_mask is not None:

src/transformers/utils/attention_visualizer.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
WHITE_SQUARE = "⬚"
3737

3838

39-
def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
39+
def generate_attention_matrix_from_mask(
40+
words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
41+
):
4042
"""
4143
Generates an attention matrix from a given attention mask.
4244
@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
8082
for j in range(n)
8183
)
8284

85+
if token_type_ids is not None:
86+
is_special = token_type_ids == 1
87+
token_type_buckets = torch.where(
88+
(token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
89+
)
90+
boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
91+
token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
92+
8393
# Print headers
8494
legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
8595
output.append(" " + legend)
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
103113
if sliding_window is not None
104114
else ""
105115
)
106-
107116
for i, word in enumerate(words):
108117
word_repr = repr(word).ljust(max_word_length)
109118
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
121130
if sliding_window is not None:
122131
sliding_window_row = " ".join(
123132
f"{YELLOW}{BLACK_SQUARE}{RESET}"
124-
if img_token in words[j] and img_token in words[i]
133+
if img_token in words[j]
134+
and img_token in words[i]
135+
and token_type_buckets[0, i] == token_type_buckets[0, j]
125136
else f"{GREEN}{BLACK_SQUARE}{RESET}"
126137
if i == j
127138
else BLACK_SQUARE
@@ -170,7 +181,8 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
170181
if self.config.model_type in PROCESSOR_MAPPING_NAMES:
171182
img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
172183
img = Image.open(requests.get(img, stream=True).raw)
173-
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
184+
image_seq_length = 5
185+
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
174186
if hasattr(processor, "image_token"):
175187
image_token = processor.image_token
176188
else:
@@ -179,7 +191,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
179191
if image_token:
180192
input_sentence = input_sentence.replace("<img>", image_token)
181193

182-
inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
194+
inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
183195

184196
self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
185197

@@ -223,6 +235,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
223235
img_token=self.image_token,
224236
sliding_window=getattr(self.config, "sliding_window", None),
225237
token_type_ids=kwargs.get("token_type_ids", None),
238+
image_seq_length=image_seq_length,
226239
)
227240
print(f_string)
228241
print(f"{top_bottom_border}")

0 commit comments

Comments
 (0)