Text Localization with Blip2 #549

Open
@dip9811111

Description

Starting from the tutorial (link) and from the function compute_gradcam in BlipITM (link), I'm trying to obtain the same result using Blip2ITM. The function getAttMap is at (link).
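For reference, setting save_attention = True works because LAVIS's Q-Former self-attention stores the attention probabilities and registers a backward hook on them. A minimal sketch of that mechanism (simplified; the method names follow LAVIS's Qformer.py, the rest is illustrative, not the verbatim implementation):

import torch

# Sketch of the attention-saving hooks (assumption: mirrors BertSelfAttention
# in LAVIS's Qformer.py; simplified here).
class SelfAttentionWithHooks(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.save_attention = False
        self.attention_map = None
        self.attn_gradients = None

    def save_attn_gradients(self, grads):
        self.attn_gradients = grads

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attn):
        self.attention_map = attn

    def get_attention_map(self):
        return self.attention_map

    def forward(self, attention_probs, value):
        # attention_probs must require grad for the hook to fire.
        if self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)
        return attention_probs @ value

This is why, after loss.backward(), both get_attention_map() and get_attn_gradients() return populated tensors.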

This is my code:

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizerFast
from lavis.models import load_model_and_preprocess


def compute_gradcam_new(model, visual_input, text_input, tokenized_text, block_num=None):
    # Hook the cross-attention of the chosen Q-Former block so the attention
    # map and its gradients are stored during the forward/backward pass.
    target_layer = model.Qformer.bert.encoder.layer[block_num].crossattention
    target_layer.self.save_attention = True

    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    loss = output[:, 1].sum()  # ITM "match" logit

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        mask = tokenized_text.attention_mask.view(
            tokenized_text.attention_mask.size(0), 1, -1, 1, 1
        )

        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2  # drop [CLS]/[SEP]
        token_length = token_length.cpu()

        grads = target_layer.self.get_attn_gradients()
        cams = target_layer.self.get_attention_map()

        # Keep the first N rows (N = text length) and drop the image [CLS]
        # column, then reshape the 256 patch scores into a 16x16 grid.
        cams = cams[:, :, :mask.shape[2], 1:].reshape(visual_input.size(0), 12, -1, 16, 16) * mask
        grads = grads[:, :, :mask.shape[2], 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 16, 16) * mask

        gradcams = cams * grads
        gradcam = gradcams[0].mean(0).cpu().detach()  # average over the 12 heads

        # image_path and tokenizer are module-level globals defined below.
        rgb_image = cv2.imread(image_path)[:, :, ::-1]
        rgb_image = np.float32(rgb_image) / 255

        folder_path_images = ".../folderImages"

        # Save one heatmap per text token; getAttMap comes from the BLIP
        # tutorial (see link above).
        for i, token_id in enumerate(tokenized_text.input_ids[0]):
            word = tokenizer.decode([token_id]).replace("##", "")
            gradcam_image = getAttMap(rgb_image, gradcam[i])
            fig_, ax_ = plt.subplots(1, 1, figsize=(15, 5))
            ax_.imshow(gradcam_image)
            ax_.set_yticks([])
            ax_.set_xticks([])
            ax_.set_xlabel(word)
            path_save_image = f"{folder_path_images}/{i}.png"
            fig_.savefig(path_save_image, bbox_inches='tight')
            plt.close(fig_)


device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
model_original, image_processor, text_processor = load_model_and_preprocess("blip2_image_text_matching", "pretrain", 
                                                                   device=device, 
                                                                   is_eval=True,
                                                                  )

image_path = ".../tryImage.jpg"
image = Image.open(image_path).convert('RGB') 

visual_input = torch.stack([image_processor['eval'](image)]).to(device)
text_input = "The sun shines on the colosseum in rome"
text_input = text_processor["eval"](text_input)

output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
tokenized_text = tokenizer(text_input, return_tensors="pt").to(device)
compute_gradcam_new(model, visual_input, text_input, tokenized_text, block_num=10)
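Note that in BLIP-2 the Q-Former only inserts cross-attention every cross_attention_freq layers (2 by default), so not every block index has a crossattention module. A quick check, assuming LAVIS's BertLayer only defines the attribute on those blocks:

# List which Q-Former blocks actually carry cross-attention (assumption:
# .crossattention exists only where layer_num % cross_attention_freq == 0).
for idx, layer in enumerate(model.Qformer.bert.encoder.layer):
    print(idx, hasattr(layer, "crossattention"))
# With cross_attention_freq=2, even-indexed blocks (0, 2, ..., 10) should
# print True, which is why block_num=10 is a usable target here.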

Here I used model.Qformer.bert.encoder.layer[10] as the target layer. What differs from BlipITM is that there cams and grads have a dynamic shape [1, 12, N, 577], where N is the number of tokens in the input text (577 = 24 x 24 patches + 1 image [CLS] token, for BLIP's 384x384 input with 16x16 patches).

In Blip2ITM, instead, the Q-Former is instantiated with num_query_token=32, so grads and cams always have shape [1, 12, 32, 257] (257 = 16 x 16 patches + 1 [CLS] for the 224x224 input with patch size 14); the 32 rows correspond to the learned query tokens that cross-attend to the image.
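(A quick way to confirm the query count, assuming the query_tokens parameter exposed by LAVIS's Blip2Qformer base class:)

# query_tokens is the learned query parameter (assumption: attribute name as
# in LAVIS's blip2_qformer.py).
print(model.query_tokens.shape)   # expected: torch.Size([1, 32, 768])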

For example, with the input text above I get:

cams.shape  = torch.Size([1, 12, 32, 257])
grads.shape = torch.Size([1, 12, 32, 257])
mask.shape = torch.Size([1, 1, 8, 1, 1])

So, to compute grads * cams * mask, I tried keeping only the first N (= mask.shape[2]) rows:

cams = cams[:, :, :mask.shape[2], 1:].reshape(visual_input.size(0), 12, -1, 16, 16) * mask
grads = grads[:, :, :mask.shape[2], 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 16, 16) * mask

Doing this I get no error, but the resulting Grad-CAM is awful and doesn't make sense at all.
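To illustrate why nothing fails: the slicing and the broadcast are shape-compatible no matter what the 32 rows represent. A toy sketch with random tensors (hypothetical sizes matching the ones above):

import torch

# Hypothetical tensors mirroring the shapes above, not real model output.
cams = torch.rand(1, 12, 32, 257)        # [batch, heads, query tokens, 1 + 16*16]
mask = torch.ones(1, 1, 8, 1, 1)         # text attention mask viewed 5-D, N = 8
sliced = cams[:, :, :mask.shape[2], 1:].reshape(1, 12, -1, 16, 16) * mask
print(sliced.shape)                      # torch.Size([1, 12, 8, 16, 16])
# The broadcast succeeds silently: the first 8 of the 32 rows are simply the
# first 8 query positions.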
What's wrong with this?
