Skip to content

Commit

Permalink
replicate demo
Browse files Browse the repository at this point in the history
  • Loading branch information
Chenxi authored and Chenxi committed Feb 6, 2022
1 parent 25511dc commit b520dc7
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The demo includes code for:
3. Multimodal / unimodal feature extraction
4. Image-text matching

Replicate web demo and Docker image is available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)

Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/BLIP)

### Pre-trained checkpoints:
Expand Down
17 changes: 17 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
build:
gpu: true
cuda: "11.1"
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "ipython==7.30.1"
- "torchvision==0.11.1"
- "torch==1.10.0"
- "timm==0.4.12"
- "transformers==4.15.0"
- "fairscale==0.4.4"
- "pycocoevalcap==1.2"

predict: "predict.py:Predictor"
98 changes: 98 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Download the weights in ./checkpoints beforehand for fast inference
wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
"""

from pathlib import Path

from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import cog

from models.blip import blip_decoder
from models.blip_vqa import blip_vqa
from models.blip_itm import blip_itm


class Predictor(cog.Predictor):
def setup(self):
self.device = "cuda:0"

self.models = {
'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
image_size=384, vit='base'),
'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
image_size=480, vit='base'),
'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
image_size=384, vit='base')
}

@cog.input(
"image",
type=Path,
help="input image",
)
@cog.input(
"task",
type=str,
default='image_captioning',
options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
help="Choose a task.",
)
@cog.input(
"question",
type=str,
default=None,
help="Type question for the input image for visual question answering task.",
)
@cog.input(
"caption",
type=str,
default=None,
help="Type caption for the input image for image text matching task.",
)
def predict(self, image, task, question, caption):
if task == 'visual_question_answering':
assert question is not None, 'Please type a question for visual question answering task.'
if task == 'image_text_matching':
assert caption is not None, 'Please type a caption for mage text matching task.'

im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
model = self.models[task]
model.eval()
model = model.to(self.device)

if task == 'image_captioning':
with torch.no_grad():
caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
return 'Caption: ' + caption[0]

if task == 'visual_question_answering':
with torch.no_grad():
answer = model(im, question, train=False, inference='generate')
return 'Answer: ' + answer[0]

# image_text_matching
itm_output = model(im, caption, match_head='itm')
itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
itc_score = model(im, caption, match_head='itc')
return f'The image and text is matched with a probability of {itm_score.item():.4f}.\n' \
f'The image feature and text feature has a cosine similarity of {itc_score.item():.4f}.'


def load_image(image, image_size, device):
raw_image = Image.open(str(image)).convert('RGB')

w, h = raw_image.size

transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
image = transform(raw_image).unsqueeze(0).to(device)
return image

0 comments on commit b520dc7

Please sign in to comment.