src/transformers/models/gemma3/modeling_gemma3.py (13 additions, 2 deletions)
@@ -1062,10 +1062,21 @@ def _update_causal_mask(
         if token_type_ids is not None and sequence_length != 1:
             token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
             token_type_mask[token_type_ids == 0] = False  # if text token, do not change anything
-            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
+            # Find where a new image block starts: 1 if image and previous not image.
+            # Images cannot attend to future images, but can attend to all previous images and within their own image bidirectionally.
+            is_image = token_type_ids == 1
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+
+            same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
+            same_image_mask[image_group_ids == -1] = False  # remove non-image tokens
+            image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
             causal_mask = causal_mask.clone()
             causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
-                token_type_mask, 0.0
+                image_mask, 0.0
             )
 
         if attention_mask is not None:
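For reference, a minimal standalone sketch of the grouping step above, run on a toy batch (the token_type_ids values are illustrative assumptions, not inputs from the PR):

import torch
from torch import nn

# Toy batch: text, a 3-token image, text, then a 2-token image (0 = text, 1 = image).
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 1, 1]])

is_image = token_type_ids == 1
# A new image block starts wherever a token is an image but its predecessor is not.
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))

print(image_group_ids)  # tensor([[-1,  0,  0,  0, -1,  1,  1]])

Tokens of the first image share group id 0 and tokens of the second share group id 1, so the equality test image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2) marks exactly the same-image pairs.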
src/transformers/models/gemma3/modular_gemma3.py (13 additions, 2 deletions)
@@ -781,10 +781,21 @@ def _update_causal_mask(
         if token_type_ids is not None and sequence_length != 1:
             token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
             token_type_mask[token_type_ids == 0] = False  # if text token, do not change anything
-            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
+            # Find where a new image block starts: 1 if image and previous not image.
+            # Images cannot attend to future images, but can attend to all previous images and within their own image bidirectionally.
+            is_image = token_type_ids == 1
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+
+            same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
+            same_image_mask[image_group_ids == -1] = False  # remove non-image tokens
+            image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
             causal_mask = causal_mask.clone()
             causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
-                token_type_mask, 0.0
+                image_mask, 0.0
             )
 
         if attention_mask is not None:
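The modular file carries the same change. A second sketch shows the resulting pattern end to end, again with assumed toy inputs, and with a boolean mask standing in for the additive float mask that masked_fill(image_mask, 0.0) operates on in the real code:

import torch
from torch import nn

# Two 2-token images separated by one text token (0 = text, 1 = image).
token_type_ids = torch.tensor([[1, 1, 0, 1, 1]])
seq_len = token_type_ids.shape[1]

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # True = may attend

token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False

is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False
image_mask = token_type_mask & same_image_mask

allowed = causal | image_mask[0]
print(allowed.int())
# Row 0 is [1, 1, 0, 0, 0]: the first image token attends to its own block
# bidirectionally but no longer to the image starting at position 3, which the
# old token_type_mask-based fill would have allowed.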
src/transformers/utils/attention_visualizer.py (18 additions, 5 deletions)
@@ -36,7 +36,9 @@
 WHITE_SQUARE = "⬚"
 
 
-def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
+def generate_attention_matrix_from_mask(
+    words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
+):
     """
     Generates an attention matrix from a given attention mask.
 
@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(
         for j in range(n)
     )
 
+    if token_type_ids is not None:
+        is_special = token_type_ids == 1
+        token_type_buckets = torch.where(
+            (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
+        )
+        boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
+        token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
+
     # Print headers
     legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
     output.append(" " + legend)
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(
         if sliding_window is not None
         else ""
     )
-
     for i, word in enumerate(words):
         word_repr = repr(word).ljust(max_word_length)
         colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(
         if sliding_window is not None:
             sliding_window_row = " ".join(
                 f"{YELLOW}{BLACK_SQUARE}{RESET}"
-                if img_token in words[j] and img_token in words[i]
+                if img_token in words[j]
+                and img_token in words[i]
+                and token_type_buckets[0, i] == token_type_buckets[0, j]
                 else f"{GREEN}{BLACK_SQUARE}{RESET}"
                 if i == j
                 else BLACK_SQUARE
@@ -170,7 +181,8 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
         if self.config.model_type in PROCESSOR_MAPPING_NAMES:
             img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
             img = Image.open(requests.get(img, stream=True).raw)
-            processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
+            image_seq_length = 5
+            processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
             if hasattr(processor, "image_token"):
                 image_token = processor.image_token
             else:
@@ -179,7 +191,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
             if image_token:
                 input_sentence = input_sentence.replace("<img>", image_token)
 
-            inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
+            inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
 
             self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
 
@@ -223,6 +235,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
             img_token=self.image_token,
             sliding_window=getattr(self.config, "sliding_window", None),
             token_type_ids=kwargs.get("token_type_ids", None),
+            image_seq_length=image_seq_length,
         )
         print(f_string)
         print(f"{top_bottom_border}")