Skip to content

Commit c681a8c

Browse files
authored
Converted BiT-S-R50x1 (#1045)
* Add BiT-S-R50x1 * a
1 parent c997bd5 commit c681a8c

File tree

4 files changed

+237
-0
lines changed

4 files changed

+237
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
2+
from brainscore_vision import model_registry
3+
from .model import get_layers,get_model
4+
5+
6+
model_registry['BiT-S-R50x1'] = \
7+
lambda: ModelCommitment(identifier='BiT-S-R50x1', activations_model=get_model('BiT-S-R50x1'), layers=get_layers('BiT-S-R50x1'))
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from collections import OrderedDict # pylint: disable=g-importing-member
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
import functools
6+
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
7+
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
8+
from brainscore_vision.model_helpers.check_submission import check_models
9+
import requests
10+
import numpy as np
11+
import io
12+
13+
class StdConv2d(nn.Conv2d):
14+
15+
def forward(self, x):
16+
w = self.weight
17+
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
18+
w = (w - m) / torch.sqrt(v + 1e-10)
19+
return F.conv2d(x, w, self.bias, self.stride, self.padding,
20+
self.dilation, self.groups)
21+
22+
23+
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
24+
return StdConv2d(cin, cout, kernel_size=3, stride=stride,
25+
padding=1, bias=bias, groups=groups)
26+
27+
28+
def conv1x1(cin, cout, stride=1, bias=False):
29+
return StdConv2d(cin, cout, kernel_size=1, stride=stride,
30+
padding=0, bias=bias)
31+
32+
33+
def tf2th(conv_weights):
34+
"""Possibly convert HWIO to OIHW."""
35+
if conv_weights.ndim == 4:
36+
conv_weights = conv_weights.transpose([3, 2, 0, 1])
37+
return torch.from_numpy(conv_weights)
38+
39+
40+
class PreActBottleneck(nn.Module):
41+
"""Pre-activation (v2) bottleneck block.
42+
43+
Follows the implementation of "Identity Mappings in Deep Residual Networks":
44+
https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
45+
46+
Except it puts the stride on 3x3 conv when available.
47+
"""
48+
49+
def __init__(self, cin, cout=None, cmid=None, stride=1):
50+
super().__init__()
51+
cout = cout or cin
52+
cmid = cmid or cout//4
53+
54+
self.gn1 = nn.GroupNorm(32, cin)
55+
self.conv1 = conv1x1(cin, cmid)
56+
self.gn2 = nn.GroupNorm(32, cmid)
57+
self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!!
58+
self.gn3 = nn.GroupNorm(32, cmid)
59+
self.conv3 = conv1x1(cmid, cout)
60+
self.relu = nn.ReLU(inplace=True)
61+
62+
if (stride != 1 or cin != cout):
63+
# Projection also with pre-activation according to paper.
64+
self.downsample = conv1x1(cin, cout, stride)
65+
66+
def forward(self, x):
67+
out = self.relu(self.gn1(x))
68+
69+
# Residual branch
70+
residual = x
71+
if hasattr(self, 'downsample'):
72+
residual = self.downsample(out)
73+
74+
# Unit's branch
75+
out = self.conv1(out)
76+
out = self.conv2(self.relu(self.gn2(out)))
77+
out = self.conv3(self.relu(self.gn3(out)))
78+
79+
return out + residual
80+
81+
def load_from(self, weights, prefix=''):
82+
convname = 'standardized_conv2d'
83+
with torch.no_grad():
84+
self.conv1.weight.copy_(tf2th(weights[f'{prefix}a/{convname}/kernel']))
85+
self.conv2.weight.copy_(tf2th(weights[f'{prefix}b/{convname}/kernel']))
86+
self.conv3.weight.copy_(tf2th(weights[f'{prefix}c/{convname}/kernel']))
87+
self.gn1.weight.copy_(tf2th(weights[f'{prefix}a/group_norm/gamma']))
88+
self.gn2.weight.copy_(tf2th(weights[f'{prefix}b/group_norm/gamma']))
89+
self.gn3.weight.copy_(tf2th(weights[f'{prefix}c/group_norm/gamma']))
90+
self.gn1.bias.copy_(tf2th(weights[f'{prefix}a/group_norm/beta']))
91+
self.gn2.bias.copy_(tf2th(weights[f'{prefix}b/group_norm/beta']))
92+
self.gn3.bias.copy_(tf2th(weights[f'{prefix}c/group_norm/beta']))
93+
if hasattr(self, 'downsample'):
94+
w = weights[f'{prefix}a/proj/{convname}/kernel']
95+
self.downsample.weight.copy_(tf2th(w))
96+
97+
98+
class ResNetV2(nn.Module):
99+
"""Implementation of Pre-activation (v2) ResNet mode."""
100+
101+
def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
102+
super().__init__()
103+
wf = width_factor # shortcut 'cause we'll use it a lot.
104+
105+
# The following will be unreadable if we split lines.
106+
# pylint: disable=line-too-long
107+
self.root = nn.Sequential(OrderedDict([
108+
('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
109+
('pad', nn.ConstantPad2d(1, 0)),
110+
('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
111+
# The following is subtly not the same!
112+
# ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
113+
]))
114+
115+
self.body = nn.Sequential(OrderedDict([
116+
('block1', nn.Sequential(OrderedDict(
117+
[('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
118+
[(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
119+
))),
120+
('block2', nn.Sequential(OrderedDict(
121+
[('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
122+
[(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
123+
))),
124+
('block3', nn.Sequential(OrderedDict(
125+
[('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
126+
[(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
127+
))),
128+
('block4', nn.Sequential(OrderedDict(
129+
[('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
130+
[(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
131+
))),
132+
]))
133+
# pylint: enable=line-too-long
134+
135+
self.zero_head = zero_head
136+
self.head = nn.Sequential(OrderedDict([
137+
('gn', nn.GroupNorm(32, 2048*wf)),
138+
('relu', nn.ReLU(inplace=True)),
139+
('avg', nn.AdaptiveAvgPool2d(output_size=1)),
140+
('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)),
141+
]))
142+
143+
def forward(self, x):
144+
x = self.head(self.body(self.root(x)))
145+
assert x.shape[-2:] == (1, 1) # We should have no spatial shape left.
146+
return x[...,0,0]
147+
148+
def load_from(self, weights, prefix='resnet/'):
149+
with torch.no_grad():
150+
self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long
151+
self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
152+
self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
153+
if self.zero_head:
154+
nn.init.zeros_(self.head.conv.weight)
155+
nn.init.zeros_(self.head.conv.bias)
156+
else:
157+
self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long
158+
self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
159+
160+
for bname, block in self.body.named_children():
161+
for uname, unit in block.named_children():
162+
unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')
163+
164+
165+
KNOWN_MODELS = OrderedDict([
166+
('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
167+
])
168+
ALL_MODELS = list(KNOWN_MODELS.keys())
169+
170+
R50_LAYERS = [
171+
'body.block1.unit01.relu', 'body.block1.unit02.relu',
172+
'body.block1.unit03.relu', 'body.block2.unit01.relu',
173+
'body.block2.unit02.relu', 'body.block2.unit03.relu',
174+
'body.block2.unit04.relu', 'body.block3.unit01.relu',
175+
'body.block3.unit02.relu', 'body.block3.unit03.relu',
176+
'body.block3.unit04.relu', 'body.block3.unit05.relu',
177+
'body.block3.unit06.relu', 'body.block4.unit01.relu',
178+
'body.block4.unit02.relu', 'body.block4.unit03.relu'
179+
]
180+
181+
def get_model_list():
182+
return ALL_MODELS
183+
184+
def get_weights(bit_variant):
185+
response = requests.get(f'https://storage.googleapis.com/bit_models/{bit_variant}.npz')
186+
response.raise_for_status()
187+
return np.load(io.BytesIO(response.content))
188+
189+
def get_model(name):
190+
assert name == "BiT-S-R50x1"
191+
model = KNOWN_MODELS[name](head_size=1000) # Small BiTs are pretrained on ImageNet
192+
weights = get_weights(name)
193+
model.load_from(weights)
194+
model.eval()
195+
image_size = 224
196+
preprocessing = functools.partial(load_preprocess_images, image_size=image_size)
197+
wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
198+
wrapper.image_size = image_size
199+
return wrapper
200+
201+
def get_layers(name):
202+
assert name == "BiT-S-R50x1"
203+
return R50_LAYERS
204+
205+
206+
def get_bibtex(model_identifier):
207+
return """@article{touvron2020deit,
208+
title={Training data-efficient image transformers & distillation through attention},
209+
author={Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Herv\'e J\'egou},
210+
journal={arXiv preprint arXiv:2012.12877},
211+
year={2020}
212+
}"""
213+
214+
215+
if __name__ == '__main__':
216+
# Use this method to ensure the correctness of the BaseModel implementations.
217+
# It executes a mock run of brain-score benchmarks.
218+
check_models.check_base_models(__name__)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
requests
2+
torch
3+
collections
4+
numpy
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import brainscore_vision
2+
import pytest
3+
4+
5+
@pytest.mark.travis_slow
6+
def test_has_identifier():
7+
model = brainscore_vision.load_model('BiT-S-R50x1')
8+
assert model.identifier == 'BiT-S-R50x1'

0 commit comments

Comments
 (0)