
Commit b5ececb

Fix image token mask in Gemma3 (#38295)
fix mask
1 parent c4e71e8 commit b5ececb

2 files changed (+24 −10 lines)

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 12 additions & 5 deletions
@@ -782,7 +782,7 @@ def forward(self, vision_outputs: torch.Tensor):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
 
@@ -945,7 +950,7 @@ def forward(
             if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                    token_type_ids.to(cache_position.device)
+                    token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
                 )
 
             # Create the masks
@@ -1211,7 +1216,9 @@ def create_masks_for_generate(
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
 
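For context, a minimal sketch, not part of the commit, of what the change does to the extra mask: the old function unmasked every image key/value position for every query, while the fixed one only unmasks pairs where the query and the key/value are both image tokens and lie at most `tokens_per_image` positions apart. The toy `token_type_ids` and `tokens_per_image = 4` below are made up for illustration; the real value comes from `config.mm_tokens_per_image` (typically 256 for Gemma3).

import torch

# Toy inputs (illustration only): one batch, text tokens (0) around one image block (1).
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
tokens_per_image = 4  # stand-in for config.mm_tokens_per_image

seq_len = token_type_ids.shape[1]
q_idx = torch.arange(seq_len).view(-1, 1)   # query positions as a column
kv_idx = torch.arange(seq_len).view(1, -1)  # key/value positions as a row

# Old behaviour: any query position could attend to any image token.
old_extra = (token_type_ids[0, kv_idx] == 1).expand(seq_len, seq_len)

# New behaviour: only pairs where query and key/value are both image tokens
# and are within `tokens_per_image` positions of each other.
same_image_block = (kv_idx - q_idx).abs() <= tokens_per_image
is_image_block = (token_type_ids[0, q_idx] == 1) & (token_type_ids[0, kv_idx] == 1)
new_extra = is_image_block & same_image_block

print(old_extra.int())  # columns 2-5 are ones in every row, even for text queries
print(new_extra.int())  # ones only in the 4x4 block where rows and columns are 2-5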

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 12 additions & 5 deletions
@@ -722,7 +722,7 @@ def forward(self, vision_outputs: torch.Tensor):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -732,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
 
@@ -836,7 +841,7 @@ def forward(
             if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                    token_type_ids.to(cache_position.device)
+                    token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
                )
 
             # Create the masks
@@ -1055,7 +1060,9 @@ def create_masks_for_generate(
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
 
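Both call sites pass the function as `mask_kwargs["or_mask_function"]`, i.e. it is OR-ed with the regular causal mask rather than replacing it. Below is a rough sketch of that composition using the same toy values as above; the actual OR is performed inside transformers' mask-creation helpers, not in user code.

import torch

# Toy values (illustration only), mirroring the sketch above.
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
tokens_per_image = 4
seq_len = token_type_ids.shape[1]
q_idx = torch.arange(seq_len).view(-1, 1)
kv_idx = torch.arange(seq_len).view(1, -1)

causal = kv_idx <= q_idx  # standard causal mask: keys at or before the query
image_bidirectional = (
    (token_type_ids[0, q_idx] == 1)
    & (token_type_ids[0, kv_idx] == 1)
    & ((kv_idx - q_idx).abs() <= tokens_per_image)
)

# The `or` semantics: a key/value position is visible if the causal mask allows it
# OR the token-type mask allows it (bidirectional attention inside an image block).
allowed = causal | image_bidirectional
print(allowed.int())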
