@@ -782,7 +782,7 @@ def forward(self, vision_outputs: torch.Tensor):
782782 return projected_vision_outputs .type_as (vision_outputs )
783783
784784
785- def token_type_ids_mask_function (token_type_ids : Optional [torch .Tensor ]) -> Optional [Callable ]:
785+ def token_type_ids_mask_function (token_type_ids : Optional [torch .Tensor ], tokens_per_image : int ) -> Optional [Callable ]:
786786 """
787787 This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
788788 not start and end indices.
@@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
792792 return None
793793
794794 def inner_mask (batch_idx : int , head_idx : int , q_idx : int , kv_idx : int ) -> bool :
795- # If it's 1, we need to unmask it
796- return token_type_ids [batch_idx , kv_idx ] == 1
795+ # If the difference is less than image size, both are part of the same image block
796+ same_image_block = torch .abs (kv_idx - q_idx ) <= tokens_per_image
797+ # If it's 1 for both query and key/value, we are in an image block
798+ is_image_block = (token_type_ids [batch_idx , q_idx ] == 1 ) & (token_type_ids [batch_idx , kv_idx ] == 1 )
799+
800+ # This is bidirectional attention whenever we are dealing with image tokens
801+ return is_image_block & same_image_block
797802
798803 return inner_mask
799804
@@ -945,7 +950,7 @@ def forward(
945950 if token_type_ids is not None and inputs_embeds .shape [1 ] != 1 :
946951 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
947952 mask_kwargs ["or_mask_function" ] = token_type_ids_mask_function (
948- token_type_ids .to (cache_position .device )
953+ token_type_ids .to (cache_position .device ), self . config . mm_tokens_per_image
949954 )
950955
951956 # Create the masks
@@ -1211,7 +1216,9 @@ def create_masks_for_generate(
12111216 # Add the token type ids mask for generate as well
12121217 if token_type_ids is not None and input_embeds .shape [1 ] != 1 :
12131218 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
1214- mask_kwargs ["or_mask_function" ] = token_type_ids_mask_function (token_type_ids .to (cache_position .device ))
1219+ mask_kwargs ["or_mask_function" ] = token_type_ids_mask_function (
1220+ token_type_ids .to (cache_position .device ), config .mm_tokens_per_image
1221+ )
12151222
12161223 return create_masks_for_generate (** mask_kwargs )
12171224
0 commit comments