Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions brainscore_vision/models/CLIP_RN50/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision import model_registry
from .model import get_layers,get_model


# Register the CLIP ResNet-50 model under the 'CLIP-RN50' identifier.
# The lambda defers construction: the (expensive) CLIP weights are only
# loaded when the registry entry is actually invoked by the benchmark runner.
model_registry['CLIP-RN50'] = \
    lambda: ModelCommitment(identifier='CLIP-RN50', activations_model=get_model('CLIP-RN50'), layers=get_layers('CLIP-RN50'))
47 changes: 47 additions & 0 deletions brainscore_vision/models/CLIP_RN50/helpers/clip_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import clip
import functools
import torch
from .imagenet_class_names import imagenet_class_names
from brainscore_vision.model_helpers.activations.pytorch import load_images


def _load_and_preprocess(img, process_function):
    """Load image file(s) and run CLIP's preprocessing on each one.

    :param img: image path(s) accepted by ``load_images``
    :param process_function: CLIP's preprocess transform (returns a tensor)
    :return: list of preprocessed images as numpy arrays
    """
    loaded = load_images(img)
    preprocessed = []
    for image in loaded:
        preprocessed.append(process_function(image).numpy())
    return preprocessed


# Run on GPU when available; the CLIP model and text tokens below are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class CosineSimilarityLayer(torch.nn.Module):
    """Pairwise similarity scores between two batches of row vectors.

    Given ``a`` of shape (n, d) and ``b`` of shape (m, d), returns an
    (m, n) score matrix. Note this is a plain dot product — it equals
    cosine similarity only when both inputs are already L2-normalized,
    as the caller in this file ensures.
    """

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        # (n, d) @ (d, m) -> (n, m); transpose so rows index entries of `b`.
        scores = a @ b.T
        return scores.transpose(0, 1)


class ClipModel(torch.nn.Module):
    """Wraps a CLIP model as a zero-shot ImageNet classifier.

    Text prompts ("A photo of a <class>") are encoded once at construction
    and cached; ``forward`` encodes a batch of images and scores them
    against the cached text features with a dot product (cosine similarity,
    since both sides are L2-normalized).
    """

    def __init__(self, architecture):
        # `architecture` is a CLIP model name understood by clip.load
        # (e.g. "RN50"); loading may download weights on first use.
        super().__init__()

        clmodel, preprocess = clip.load(architecture)
        self.clmodel = clmodel.eval().to(DEVICE)
        # Preprocessing callable for the activations pipeline: loads image
        # paths and applies CLIP's own transforms, yielding numpy arrays.
        self.preprocessing = functools.partial(_load_and_preprocess, process_function=preprocess)

        # One prompt per ImageNet class; the normalized text features are
        # constant for the life of the model, so compute them once here.
        text_descriptions = ["A photo of a " + label for label in imagenet_class_names]
        text_tokens = clip.tokenize(text_descriptions).to(DEVICE)
        with torch.no_grad():
            self.text_features = self.clmodel.encode_text(text_tokens).float()
            # In-place L2 normalization along the feature dimension.
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)

        self.logits = CosineSimilarityLayer()

    def forward(self, img):
        # NOTE(review): `img` is assumed to be a batch tensor already on
        # DEVICE (produced downstream from `self.preprocessing`) — confirm
        # against the model-helpers caller.
        with torch.no_grad():
            image_features = self.clmodel.encode_image(img).float()
            # Normalize so the dot product below is a cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
        # (n_images, n_classes) score matrix: one row per input image.
        return self.logits(self.text_features, image_features)


Loading