17 changes: 12 additions & 5 deletions src/transformers/models/gemma3/modeling_gemma3.py
@@ -782,7 +782,7 @@ def forward(self, vision_outputs: torch.Tensor):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
 
@@ -945,7 +950,7 @@ def forward(
         if token_type_ids is not None and inputs_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device)
+                token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
             )
 
         # Create the masks
@@ -1211,7 +1216,9 @@ def create_masks_for_generate(
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
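For context on what the new mask computes: the previous `inner_mask` unmasked every key/value position whose `token_type_ids` entry was 1, regardless of the query, so image tokens from different images could attend to each other. The new version only unmasks pairs where both the query and the key/value are image tokens and lie within `tokens_per_image` positions of each other, i.e. bidirectional attention inside a single image block. Below is a minimal standalone sketch (not part of the PR) that evaluates the new predicate on a toy sequence; the sequence, `tokens_per_image = 4`, and the broadcasting of index grids are illustration-only assumptions (the model reads the real value from `config.mm_tokens_per_image`).

```python
import torch

# Toy sequence: two 4-token images (token_type_id == 1) separated by text tokens.
# tokens_per_image = 4 is an illustration value only.
token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
tokens_per_image = 4


def token_type_ids_mask_function(token_type_ids, tokens_per_image):
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Same predicate as the PR: both positions are image tokens AND close
        # enough to belong to the same image block.
        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
        return is_image_block & same_image_block

    return inner_mask


mask_fn = token_type_ids_mask_function(token_type_ids, tokens_per_image)
seq_len = token_type_ids.shape[1]
q_idx = torch.arange(seq_len).view(-1, 1)    # queries as a column
kv_idx = torch.arange(seq_len).view(1, -1)   # keys/values as a row
extra_unmask = mask_fn(0, 0, q_idx, kv_idx)  # (seq_len, seq_len) boolean grid
print(extra_unmask.int())                    # two 4x4 blocks of ones, one per image
```

In the printed grid, the two 4x4 blocks of ones cover the two images; everything else stays False and is left to the base causal mask.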
17 changes: 12 additions & 5 deletions src/transformers/models/gemma3/modular_gemma3.py
@@ -722,7 +722,7 @@ def forward(self, vision_outputs: torch.Tensor):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -732,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
 
@@ -836,7 +841,7 @@ def forward(
         if token_type_ids is not None and inputs_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device)
+                token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
             )
 
         # Create the masks
@@ -1055,7 +1060,9 @@ def create_masks_for_generate(
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
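The same change is mirrored in modular_gemma3.py, the modular source from which modeling_gemma3.py is generated. As a reminder of why the function is registered as an `or_mask_function` (per the in-code comment, "it needs to be an `or`"): the per-image predicate is OR-ed onto the base causal mask, so text tokens stay strictly causal while image tokens of the same image also attend to positions ahead of them. The snippet below is a conceptual sketch in plain torch, not the transformers masking utilities; the toy values match the previous example.

```python
import torch

# Conceptual sketch only: how an `or` mask function relaxes a causal mask so that
# image tokens of the same image attend to each other bidirectionally.
seq_len, tokens_per_image = 16, 4
token_type_ids = torch.tensor([0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1])

q_idx = torch.arange(seq_len).view(-1, 1)
kv_idx = torch.arange(seq_len).view(1, -1)

causal = kv_idx <= q_idx                                   # standard causal mask
is_image = token_type_ids == 1
same_image_block = (kv_idx - q_idx).abs() <= tokens_per_image
image_bidirectional = is_image[q_idx] & is_image[kv_idx] & same_image_block

final_mask = causal | image_bidirectional                  # the `or` combination
print(final_mask.int())  # lower triangle, plus a full block over each image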