Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions brainscore_vision/models/BiT_S_R101x1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision import model_registry
from .model import get_layers,get_model


# Register the model under its Brain-Score identifier. The lambda defers
# construction (and hence the pretrained-weight download inside get_model)
# until the registry entry is actually resolved.
model_registry['BiT-S-R101x1'] = \
    lambda: ModelCommitment(identifier='BiT-S-R101x1', activations_model=get_model('BiT-S-R101x1'), layers=get_layers('BiT-S-R101x1'))
223 changes: 223 additions & 0 deletions brainscore_vision/models/BiT_S_R101x1/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
from collections import OrderedDict # pylint: disable=g-importing-member
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.check_submission import check_models
import requests
import numpy as np
import io

class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization.

    Before every convolution, the kernel is normalized to zero mean and
    unit variance per output channel (over the in-channel and spatial
    dimensions). The stored ``self.weight`` itself is left untouched.
    """

    def forward(self, x):
        weight = self.weight
        # Per-output-channel statistics over (in_channels, kH, kW).
        var, mean = torch.var_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        standardized = (weight - mean) / torch.sqrt(var + 1e-10)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """3x3 weight-standardized convolution with padding 1 (shape-preserving at stride 1)."""
    return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1,
                     bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    """1x1 weight-standardized convolution (pointwise, no padding)."""
    return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0,
                     bias=bias)


def tf2th(conv_weights):
    """Convert a TensorFlow checkpoint array to a torch tensor.

    4-D conv kernels are transposed from TF's HWIO layout to torch's
    OIHW layout; arrays of any other rank pass through unchanged.
    """
    hwio_to_oihw = [3, 2, 0, 1]
    if conv_weights.ndim == 4:
        conv_weights = conv_weights.transpose(hwio_to_oihw)
    return torch.from_numpy(conv_weights)


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual
    Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on the 3x3 conv when available.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        # Default channel counts when not given explicitly.
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cin)
        self.conv1 = conv1x1(cin, cmid)
        self.gn2 = nn.GroupNorm(32, cmid)
        self.conv2 = conv3x3(cmid, cmid, stride)  # stride sits here, not on conv1
        self.gn3 = nn.GroupNorm(32, cmid)
        self.conv3 = conv1x1(cmid, cout)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut (also pre-activated, per the paper) is only
        # created when shape changes; forward() detects it via hasattr().
        if stride != 1 or cin != cout:
            self.downsample = conv1x1(cin, cout, stride)

    def forward(self, x):
        preact = self.relu(self.gn1(x))

        # Shortcut branch: project the pre-activated input when present,
        # otherwise pass the raw input through.
        shortcut = self.downsample(preact) if hasattr(self, 'downsample') else x

        # Main branch: 1x1 -> 3x3 (possibly strided) -> 1x1,
        # each of the latter two preceded by GN + ReLU.
        y = self.conv1(preact)
        y = self.conv2(self.relu(self.gn2(y)))
        y = self.conv3(self.relu(self.gn3(y)))

        return y + shortcut

    def load_from(self, weights, prefix=''):
        """Copy TF checkpoint arrays (dict-like `weights`) into this block's params."""
        convname = 'standardized_conv2d'
        with torch.no_grad():
            # Units 'a'/'b'/'c' in the checkpoint map to conv1/2/3 and gn1/2/3.
            for letter, conv, gn in (('a', self.conv1, self.gn1),
                                     ('b', self.conv2, self.gn2),
                                     ('c', self.conv3, self.gn3)):
                conv.weight.copy_(tf2th(weights[f'{prefix}{letter}/{convname}/kernel']))
                gn.weight.copy_(tf2th(weights[f'{prefix}{letter}/group_norm/gamma']))
                gn.bias.copy_(tf2th(weights[f'{prefix}{letter}/group_norm/beta']))
            if hasattr(self, 'downsample'):
                w = weights[f'{prefix}a/proj/{convname}/kernel']
                self.downsample.weight.copy_(tf2th(w))


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    Structure: root (stem) -> body (4 stages of PreActBottleneck units)
    -> head (GN + ReLU + global average pool + 1x1 conv classifier).

    Args:
        block_units: number of bottleneck units per stage, e.g. [3, 4, 23, 3]
            for the R101 depth.
        width_factor: channel multiplier (1 for the x1 variants).
        head_size: number of output classes of the head conv.
        zero_head: if True, load_from() zero-initializes the head instead of
            copying checkpoint weights.

    NOTE: the OrderedDict keys ('block1', 'unit01', 'gn', ...) double as both
    the TF checkpoint key prefixes consumed by load_from() and the module
    paths listed in R101_LAYERS — do not rename them.
    """

    def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
        super().__init__()
        wf = width_factor  # shortcut 'cause we'll use it a lot.

        # The following will be unreadable if we split lines.
        # pylint: disable=line-too-long
        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
            ('pad', nn.ConstantPad2d(1, 0)),
            ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
            # The following is subtly not the same!
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Stages 2-4 downsample via stride=2 on their first unit; within a
        # stage all remaining units keep channel counts fixed.
        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
            ))),
            ('block4', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
            ))),
        ]))
        # pylint: enable=line-too-long

        self.zero_head = zero_head
        self.head = nn.Sequential(OrderedDict([
            ('gn', nn.GroupNorm(32, 2048*wf)),
            ('relu', nn.ReLU(inplace=True)),
            ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
            ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)),
        ]))

    def forward(self, x):
        """Return class logits of shape (batch, head_size)."""
        x = self.head(self.body(self.root(x)))
        assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
        # Drop the 1x1 spatial dims left by the adaptive average pool.
        return x[..., 0, 0]

    def load_from(self, weights, prefix='resnet/'):
        """Copy TF checkpoint arrays (dict-like `weights`) into this network.

        If self.zero_head is set, the classifier head is zero-initialized
        instead of loaded from the checkpoint.
        """
        with torch.no_grad():
            self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))  # pylint: disable=line-too-long
            self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
            self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
            if self.zero_head:
                nn.init.zeros_(self.head.conv.weight)
                nn.init.zeros_(self.head.conv.bias)
            else:
                self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))  # pylint: disable=line-too-long
                self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))

            # Delegate per-unit loading; module names form the key prefixes.
            for bname, block in self.body.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')


# Registry of constructable BiT variants: identifier -> factory.
# BiT-S-R101x1 = ResNet-101 depths [3, 4, 23, 3] at width factor 1.
KNOWN_MODELS = OrderedDict(
    [('BiT-S-R101x1', lambda *args, **kwargs: ResNetV2([3, 4, 23, 3], 1, *args, **kwargs))]
)
ALL_MODELS = list(KNOWN_MODELS)

# Module paths of the per-unit ReLU activations to record, one per
# bottleneck unit, following the R101 stage depths [3, 4, 23, 3].
R101_LAYERS = (
    [f'body.block1.unit{i:02d}.relu' for i in range(1, 4)] +
    [f'body.block2.unit{i:02d}.relu' for i in range(1, 5)] +
    [f'body.block3.unit{i:02d}.relu' for i in range(1, 24)] +
    [f'body.block4.unit{i:02d}.relu' for i in range(1, 4)]
)

def get_weights(bit_variant, timeout=60):
    """Download the pretrained checkpoint for `bit_variant` from the public
    BiT weight bucket and return it loaded via np.load (an .npz archive
    mapping parameter names to arrays).

    Args:
        bit_variant: model identifier, e.g. 'BiT-S-R101x1'.
        timeout: seconds before the HTTP request is aborted. Without one,
            requests.get can block indefinitely on a stalled connection.

    Raises:
        requests.HTTPError: if the download does not return a 2xx status.
    """
    response = requests.get(
        f'https://storage.googleapis.com/bit_models/{bit_variant}.npz',
        timeout=timeout)
    response.raise_for_status()
    return np.load(io.BytesIO(response.content))

def get_model(name):
    """Build a PytorchWrapper around the pretrained BiT model `name`.

    Downloads the checkpoint, loads it into a freshly constructed network in
    eval mode, and attaches 224px image preprocessing.
    """
    assert name in ALL_MODELS
    # Small BiTs are pretrained on ImageNet, hence a 1000-class head.
    net = KNOWN_MODELS[name](head_size=1000)
    net.load_from(get_weights(name))
    net.eval()

    resolution = 224
    preprocessing = functools.partial(load_preprocess_images, image_size=resolution)
    wrapper = PytorchWrapper(identifier=name, model=net, preprocessing=preprocessing)
    wrapper.image_size = resolution
    return wrapper

def get_layers(name):
    """Return the module paths to record activations from for model `name`.

    Only 'BiT-S-R101x1' is registered, so the R101 layer list is returned
    for every accepted name.
    """
    assert name in ALL_MODELS
    return R101_LAYERS

def get_bibtex(model_identifier):
    """Return the BibTeX citation for the BiT (Big Transfer) model family.

    Fixes the previous return value, which cited the unrelated DeiT paper
    (Touvron et al., arXiv:2012.12877) instead of BiT.
    """
    return """@misc{kolesnikov2020big,
    title={Big Transfer (BiT): General Visual Representation Learning},
    author={Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Joan Puigcerver and Jessica Yung and Sylvain Gelly and Neil Houlsby},
    year={2020},
    eprint={1912.11370},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}"""


if __name__ == '__main__':
    # Use this method to ensure the correctness of the BaseModel implementations.
    # It executes a mock run of brain-score benchmarks (requires network
    # access to download the pretrained weights).
    check_models.check_base_models(__name__)
4 changes: 4 additions & 0 deletions brainscore_vision/models/BiT_S_R101x1/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
numpy
requests
torch
8 changes: 8 additions & 0 deletions brainscore_vision/models/BiT_S_R101x1/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import brainscore_vision
import pytest


@pytest.mark.travis_slow
def test_has_identifier():
    """Loading the registered model yields a commitment carrying its identifier."""
    identifier = 'BiT-S-R101x1'
    model = brainscore_vision.load_model(identifier)
    assert model.identifier == identifier