36
36
WHITE_SQUARE = "⬚"
37
37
38
38
39
- def generate_attention_matrix_from_mask (words , mask , img_token = "<img>" , sliding_window = None , token_type_ids = None ):
39
+ def generate_attention_matrix_from_mask (
40
+ words , mask , img_token = "<img>" , sliding_window = None , token_type_ids = None , image_seq_length = None
41
+ ):
40
42
"""
41
43
Generates an attention matrix from a given attention mask.
42
44
@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
80
82
for j in range (n )
81
83
)
82
84
85
+ if token_type_ids is not None :
86
+ is_special = token_type_ids == 1
87
+ token_type_buckets = torch .where (
88
+ (token_type_ids .cumsum (- 1 ) % 5 + is_special ).bool (), token_type_ids .cumsum (- 1 ), 0
89
+ )
90
+ boundaries = torch .arange (0 , image_seq_length + 1 , image_seq_length )
91
+ token_type_buckets = torch .bucketize (token_type_buckets , boundaries = boundaries )
92
+
83
93
# Print headers
84
94
legend = f"{ GREEN } { BLACK_SQUARE } { RESET } : i == j (diagonal) { YELLOW } { BLACK_SQUARE } { RESET } : token_type_ids"
85
95
output .append (" " + legend )
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
103
113
if sliding_window is not None
104
114
else ""
105
115
)
106
-
107
116
for i , word in enumerate (words ):
108
117
word_repr = repr (word ).ljust (max_word_length )
109
118
colored_word = f"{ YELLOW } { word_repr } { RESET } " if img_token in word else word_repr
@@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
121
130
if sliding_window is not None :
122
131
sliding_window_row = " " .join (
123
132
f"{ YELLOW } { BLACK_SQUARE } { RESET } "
124
- if img_token in words [j ] and img_token in words [i ]
133
+ if img_token in words [j ]
134
+ and img_token in words [i ]
135
+ and token_type_buckets [0 , i ] == token_type_buckets [0 , j ]
125
136
else f"{ GREEN } { BLACK_SQUARE } { RESET } "
126
137
if i == j
127
138
else BLACK_SQUARE
@@ -170,7 +181,8 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
170
181
if self .config .model_type in PROCESSOR_MAPPING_NAMES :
171
182
img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
172
183
img = Image .open (requests .get (img , stream = True ).raw )
173
- processor = AutoProcessor .from_pretrained (self .repo_id , image_seq_length = 5 )
184
+ image_seq_length = 5
185
+ processor = AutoProcessor .from_pretrained (self .repo_id , image_seq_length = image_seq_length )
174
186
if hasattr (processor , "image_token" ):
175
187
image_token = processor .image_token
176
188
else :
@@ -179,7 +191,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
179
191
if image_token :
180
192
input_sentence = input_sentence .replace ("<img>" , image_token )
181
193
182
- inputs = processor (img , input_sentence , suffix = suffix , return_tensors = "pt" )
194
+ inputs = processor (images = img , text = input_sentence , suffix = suffix , return_tensors = "pt" )
183
195
184
196
self .image_token = processor .tokenizer .convert_ids_to_tokens ([processor .image_token_id ])[0 ]
185
197
@@ -223,6 +235,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
223
235
img_token = self .image_token ,
224
236
sliding_window = getattr (self .config , "sliding_window" , None ),
225
237
token_type_ids = kwargs .get ("token_type_ids" , None ),
238
+ image_seq_length = image_seq_length ,
226
239
)
227
240
print (f_string )
228
241
print (f"{ top_bottom_border } " )
0 commit comments