Description
Starting from the tutorial link and the function compute_gradcam in BlipITM link, I am trying to obtain the same result but using Blip2ITM. The function getAttMap is at link.
This is my code:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizerFast
from lavis.models import load_model_and_preprocess
from lavis.common.gradcam import getAttMap

def compute_gradcam_new(model, visual_input, text_input, tokenized_text, block_num=None):
    # Hook the cross-attention of the chosen Q-Former layer so that it stores
    # its attention map and the gradients flowing through it.
    target_layer = model.Qformer.bert.encoder.layer[block_num].crossattention
    target_layer.self.save_attention = True

    # ITM forward pass; backpropagate the "match" logit to populate the gradients.
    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    loss = output[:, 1].sum()
    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        mask = tokenized_text.attention_mask.view(
            tokenized_text.attention_mask.size(0), 1, -1, 1, 1
        )
        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2
        token_length = token_length.cpu()

        grads = target_layer.self.get_attn_gradients()
        cams = target_layer.self.get_attention_map()

        # Drop the CLS patch, keep only the first N rows (N = number of text tokens)
        # and fold the remaining 256 patches into a 16x16 grid.
        cams = cams[:, :, :mask.shape[2], 1:].reshape(visual_input.size(0), 12, -1, 16, 16) * mask
        grads = grads[:, :, :mask.shape[2], 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 16, 16) * mask

        gradcams = cams * grads
        gradcam = gradcams[0].mean(0).cpu().detach()

    # Overlay one heatmap per text token on the original image and save it.
    text_input = tokenized_text
    rgb_image = cv2.imread(image_path)[:, :, ::-1]  # image_path and tokenizer are module-level globals
    rgb_image = np.float32(rgb_image) / 255
    folder_path_images = ".../folderImages"
    for i, token_id in enumerate(text_input.input_ids[0][:]):
        word = tokenizer.decode([token_id])
        word = word.replace("##", "")
        gradcam_image = getAttMap(rgb_image, gradcam[i])
        fig_, ax_ = plt.subplots(1, 1, figsize=(15, 5))
        ax_.imshow(gradcam_image)
        ax_.set_yticks([])
        ax_.set_xticks([])
        ax_.set_xlabel(word)
        path_save_image = f"{folder_path_images}/{i}.png"
        fig_.savefig(path_save_image, bbox_inches='tight')

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
model, image_processor, text_processor = load_model_and_preprocess(
    "blip2_image_text_matching", "pretrain", device=device, is_eval=True,
)
image_path = ".../tryImage.jpg"
image = Image.open(image_path).convert('RGB')
visual_input = torch.stack([image_processor['eval'](image)]).to(device)
text_input = "The sun shines on the colosseum in rome"
text_input = text_processor["eval"](text_input)
output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
tokenized_text = tokenizer(text_input, return_tensors="pt").to(device)
compute_gradcam_new(model, visual_input, text_input, tokenized_text, block_num=10)
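For reference, with match_head="itm" the model returns the two-way (no-match, match) logits, so output[:, 1] inside the function is the "match" logit; a quick sanity check on the variables above (the printed shape is what I expect, not something special to my setup):

print(output.shape)                        # torch.Size([1, 2]) -> (no-match, match) logits
print(torch.softmax(output, dim=1)[:, 1])  # probability that image and text match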
Here I used model.Qformer.bert.encoder.layer[10] as the target layer. The difference from BlipITM is that there cams and grads have a dynamic shape [1, 12, N, 577], where N is the number of tokens in the input text.
In Blip2ITM, instead, the Q-Former is instantiated with num_query_token=32, so grads and cams always have the shape [1, 12, 32, 257].
For example, with that input text I got:
cams.shape = torch.Size([1, 12, 32, 257])
grads.shape = torch.Size([1, 12, 32, 257])
mask.shape = torch.Size([1, 1, 8, 1, 1])
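To double-check where the fixed 32 comes from, I inspected the learned query tokens and which Q-Former layers actually carry a cross-attention module (a quick check on the model loaded above; the expected sizes are noted in the comments):

print(model.query_tokens.shape)  # expected: torch.Size([1, 32, 768]) -> 32 learned query tokens
for i, layer in enumerate(model.Qformer.bert.encoder.layer):
    # cross-attention is only present every cross_attention_freq layers
    print(i, hasattr(layer, "crossattention"))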
So, to make grads * cams * mask computable, I tried keeping only the first N rows (mask.shape[2]):
cams = cams[:, :, :mask.shape[2], 1:].reshape(visual_input.size(0), 12, -1, 16, 16) * mask
grads = grads[:, :, :mask.shape[2], 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 16, 16) * mask
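The shape arithmetic of that slice, checked with dummy tensors of the sizes reported above, shows why the multiplication goes through without complaints:

import torch

cams = torch.rand(1, 12, 32, 257)        # [batch, heads, 32 query tokens, 256 patches + CLS]
mask = torch.ones(1, 1, 8, 1, 1)         # [batch, 1, 8 text tokens, 1, 1]
sliced = cams[:, :, :mask.shape[2], 1:]  # keep the first 8 of the 32 query rows, drop the CLS patch
print(sliced.shape)                                       # torch.Size([1, 12, 8, 256])
print((sliced.reshape(1, 12, -1, 16, 16) * mask).shape)   # torch.Size([1, 12, 8, 16, 16])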
This runs without errors, but the resulting Grad-CAM looks awful and doesn't make sense at all.
What's wrong with this?