1+ # Custom Pytorch model from:
2+ # https://github.com/brain-score/candidate_models/blob/master/examples/score-model.ipynb
3+ from brainscore_vision .model_helpers .check_submission import check_models
4+ import numpy as np
5+ import torch
6+ #from torch import nn
7+ import functools
8+ from brainscore_vision .model_helpers .activations .pytorch import PytorchWrapper
9+ from brainscore_vision .model_helpers .brain_transformation import ModelCommitment
10+ from brainscore_vision .model_helpers .activations .pytorch import load_preprocess_images
11+ from brainscore_vision .model_helpers .s3 import load_weight_file
12+ import torch
13+ import torch .nn as nn
14+ import torch .nn .functional as F
15+ from torch .nn import init
16+ from functools import reduce
17+
18+ device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
19+
20+ # define your custom model here:
21+ class SKConv (nn .Module ):
22+ def __init__ (self ,in_channels ,out_channels ,stride = 1 ,M = 2 ,r = 16 ,L = 32 , groups = 32 ):
23+
24+ super (SKConv ,self ).__init__ ()
25+ d = max (in_channels // r ,L )
26+ self .M = M
27+ self .out_channels = out_channels
28+ self .conv = nn .ModuleList ()
29+ for i in range (M ):
30+
31+ conv1 = nn .Conv2d (in_channels ,out_channels ,3 ,stride ,padding = 1 + i ,dilation = 1 + i ,groups = groups ,bias = False )
32+ init .kaiming_normal_ (conv1 .weight )
33+ self .conv .append (nn .Sequential (conv1 ,
34+ nn .BatchNorm2d (out_channels ),
35+ nn .ReLU (inplace = True )))
36+ self .global_pool = nn .AdaptiveAvgPool2d (1 )
37+ conv_fc = nn .Conv2d (out_channels ,d ,1 ,bias = False )
38+ init .normal_ (conv_fc .weight , std = 0.01 )
39+ self .fc1 = nn .Sequential (conv_fc ,
40+ nn .BatchNorm2d (d ),
41+ nn .ReLU (inplace = True ))
42+ self .fc2 = nn .Conv2d (d ,out_channels * M ,1 ,1 ,bias = False )
43+ init .normal_ (self .fc2 .weight , std = 0.01 )
44+ self .softmax = nn .Softmax (dim = 1 )
45+
46+ def forward (self , input ):
47+ batch_size = input .size (0 )
48+ output = []
49+ for i ,conv in enumerate (self .conv ):
50+ output .append (conv (input ))
51+ U = reduce (lambda x ,y :x + y ,output )
52+ s = self .global_pool (U )
53+ z = self .fc1 (s )
54+ a_b = self .fc2 (z )
55+ a_b = a_b .reshape (batch_size ,self .M ,self .out_channels ,- 1 )
56+ a_b = self .softmax (a_b )
57+ a_b = list (a_b .chunk (self .M ,dim = 1 ))
58+ a_b = list (map (lambda x :x .reshape (batch_size ,self .out_channels ,1 ,1 ),a_b ))
59+ V = list (map (lambda x ,y :x * y ,output ,a_b ))
60+ V = reduce (lambda x ,y :x + y ,V )
61+ return V
62+
63+ class GRCL (nn .Module ):
64+ def __init__ (self , inplanes , planes , downsample = True , iter = 3 , SKconv = True , expansion = 2 ):
65+ super (GRCL , self ).__init__ ()
66+
67+ self .iter = iter
68+ self .expansion = expansion
69+ # feed-forward part
70+ self .add_module ('bn_f' , nn .BatchNorm2d (inplanes ))
71+ self .add_module ('relu_f' , nn .ReLU (inplace = True ))
72+ conv_f = nn .Conv2d (inplanes , int (planes * self .expansion ), kernel_size = 3 , stride = 1 , padding = 1 , bias = False , groups = 32 )
73+ init .kaiming_normal_ (conv_f .weight )
74+ self .add_module ('conv_f' , conv_f )
75+
76+ self .add_module ('bn_g_f' , nn .BatchNorm2d (inplanes ))
77+ self .add_module ('relu_g_f' , nn .ReLU (inplace = True ))
78+ conv_g_f = nn .Conv2d (inplanes , int (planes * self .expansion ), kernel_size = 1 , stride = 1 , padding = 0 , bias = True , groups = 32 )
79+ init .normal_ (conv_g_f .weight , std = 0.01 )
80+ self .add_module ('conv_g_f' , conv_g_f )
81+ self .conv_g_r = nn .Conv2d (int (planes * self .expansion ), int (planes * self .expansion ), kernel_size = 1 , stride = 1 , padding = 0 , bias = False , groups = 32 )
82+ self .add_module ('sig' , nn .Sigmoid ())
83+
84+ # recurrent part
85+ for i in range (0 , self .iter ):
86+ layers = []
87+ layers_g_bn = []
88+
89+ layers .append (nn .BatchNorm2d (planes * self .expansion ))
90+ layers .append (nn .ReLU (inplace = True ))
91+ conv_1 = nn .Conv2d (int (planes * self .expansion ), planes , kernel_size = 1 , stride = 1 , padding = 0 , bias = False )
92+ init .kaiming_normal_ (conv_1 .weight )
93+ layers .append (conv_1 )
94+
95+ layers .append (nn .BatchNorm2d (planes ))
96+ layers .append (nn .ReLU (inplace = True ))
97+
98+ if SKconv :
99+ layers .append (SKConv (planes , planes ))
100+ else :
101+ layers .append (nn .Conv2d (planes , planes , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
102+ layers .append (nn .BatchNorm2d (planes ))
103+ layers .append (nn .ReLU (inplace = True ))
104+
105+ conv_2 = nn .Conv2d (planes , int (planes * self .expansion ), kernel_size = 1 , stride = 1 , padding = 0 , bias = False )
106+ init .kaiming_normal_ (conv_2 .weight )
107+ layers .append (conv_2 )
108+ layers_g_bn .append (nn .BatchNorm2d (int (planes * self .expansion )))
109+
110+ layers_g_bn .append (nn .ReLU (inplace = True ))
111+
112+ self .add_module ('iter_' + str (i + 1 ), nn .Sequential (* layers ))
113+ self .add_module ('iter_g_' + str (i + 1 ), nn .Sequential (* layers_g_bn ))
114+
115+ self .downsample = downsample
116+ if self .downsample :
117+ self .add_module ('d_bn' , nn .BatchNorm2d (planes * self .expansion ))
118+ self .add_module ('d_relu' , nn .ReLU (inplace = True ))
119+ d_conv = nn .Conv2d (int (planes * self .expansion ), int (planes * self .expansion ), kernel_size = 1 , stride = 1 , padding = 0 , bias = False )
120+ init .kaiming_normal_ (d_conv .weight )
121+ self .add_module ('d_conv' , d_conv )
122+ self .add_module ('d_ave' , nn .AvgPool2d ((2 , 2 ), stride = 2 ))
123+
124+ self .add_module ('d_bn_1' , nn .BatchNorm2d (planes * self .expansion ))
125+ self .add_module ('d_relu_1' , nn .ReLU (inplace = True ))
126+ d_conv_1 = nn .Conv2d (int (planes * self .expansion ), planes , kernel_size = 1 , stride = 1 , padding = 0 ,
127+ bias = False )
128+ init .kaiming_normal_ (d_conv_1 .weight )
129+ self .add_module ('d_conv_1' , d_conv_1 )
130+
131+ self .add_module ('d_bn_3' , nn .BatchNorm2d (planes ))
132+ self .add_module ('d_relu_3' , nn .ReLU (inplace = True ))
133+
134+ if SKconv :
135+ d_conv_3 = SKConv (planes , planes , stride = 2 )
136+ self .add_module ('d_conv_3' , d_conv_3 )
137+ else :
138+ d_conv_3 = nn .Conv2d (planes , planes , kernel_size = 3 , stride = 2 , padding = 1 , bias = False )
139+ init .kaiming_normal_ (d_conv_3 .weight )
140+ self .add_module ('d_conv_3' , d_conv_3 )
141+
142+ d_conv_1e = nn .Conv2d (planes , int (planes * self .expansion ), kernel_size = 1 , stride = 1 , padding = 0 , bias = False )
143+ init .kaiming_normal_ (d_conv_1e .weight )
144+ self .add_module ('d_conv_1e' , d_conv_1e )
145+
146+ def forward (self , x ):
147+ # feed-forward
148+ x_bn = self .bn_f (x )
149+ x_act = self .relu_f (x_bn )
150+ x_s = self .conv_f (x_act )
151+
152+ x_g_bn = self .bn_g_f (x )
153+ x_g_act = self .relu_g_f (x_g_bn )
154+ x_g_s = self .conv_g_f (x_g_act )
155+
156+ # recurrent
157+ for i in range (0 , self .iter ):
158+ x_g_r = self .conv_g_r (self .__dict__ ['_modules' ]["iter_g_%s" % str (i + 1 )](x_s ))
159+ x_s = self .__dict__ ['_modules' ]["iter_%s" % str (i + 1 )](x_s ) * torch .sigmoid (x_g_r + x_g_s ) + x_s
160+
161+ if self .downsample :
162+ x_s_1 = self .d_conv (self .d_ave (self .d_relu (self .d_bn (x_s ))))
163+ x_s_2 = self .d_conv_1e (self .d_conv_3 (self .d_relu_3 (self .d_bn_3 (self .d_conv_1 (self .d_relu_1 (self .d_bn_1 (x_s )))))))
164+ x_s = x_s_1 + x_s_2
165+
166+ return x_s
167+
168+ class GRCNN (nn .Module ):
169+
170+ def __init__ (self , iters , maps , SKconv , expansion , num_classes ):
171+ """ Args:
172+ iters:iterations.
173+ num_classes: number of classes
174+ """
175+ super (GRCNN , self ).__init__ ()
176+ self .iters = iters
177+ self .maps = maps
178+ self .num_classes = num_classes
179+ self .expansion = expansion
180+
181+ self .conv1 = nn .Conv2d (3 , 64 , kernel_size = 7 , stride = 2 , padding = 3 ,
182+ bias = False )
183+
184+ init .kaiming_normal_ (self .conv1 .weight )
185+
186+ self .bn1 = nn .BatchNorm2d (64 )
187+ self .relu = nn .ReLU (inplace = True )
188+ self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
189+
190+ self .conv2 = nn .Conv2d (64 , 64 , kernel_size = 3 , stride = 1 , padding = 1 ,
191+ bias = False )
192+
193+ init .kaiming_normal_ (self .conv2 .weight )
194+
195+ self .layer1 = GRCL (64 , self .maps [0 ], True , self .iters [0 ], SKconv , self .expansion )
196+ self .layer2 = GRCL (self .maps [0 ] * self .expansion , self .maps [1 ], True , self .iters [1 ], SKconv , self .expansion )
197+ self .layer3 = GRCL (self .maps [1 ] * self .expansion , self .maps [2 ], True , self .iters [2 ], SKconv , self .expansion )
198+ self .layer4 = GRCL (self .maps [2 ] * self .expansion , self .maps [3 ], False , self .iters [3 ], SKconv , self .expansion )
199+
200+ self .lastact = nn .Sequential (nn .BatchNorm2d (self .maps [3 ]* self .expansion ), nn .ReLU (inplace = True ))
201+ self .avgpool = nn .AvgPool2d (7 )
202+ self .classifier = nn .Linear (self .maps [3 ] * self .expansion , num_classes )
203+
204+ for m in self .modules ():
205+ if isinstance (m , nn .Conv2d ):
206+ if m .bias is not None :
207+ init .zeros_ (m .bias )
208+ elif isinstance (m , nn .BatchNorm2d ):
209+ init .ones_ (m .weight )
210+ init .zeros_ (m .bias )
211+ elif isinstance (m , nn .Linear ):
212+ init .kaiming_normal_ (m .weight )
213+ init .zeros_ (m .bias )
214+
215+ def forward (self , x ):
216+ x = self .conv1 (x )
217+ x = self .bn1 (x )
218+ x = self .relu (x )
219+ x = self .maxpool (x )
220+ x = self .conv2 (x )
221+
222+ x = self .layer1 (x )
223+ x = self .layer2 (x )
224+ x = self .layer3 (x )
225+ x = self .layer4 (x )
226+
227+ x = self .lastact (x )
228+ x = self .avgpool (x )
229+ x = x .view (x .size (0 ), - 1 )
230+ return self .classifier (x )
231+
232+ def grcnn55 (num_classes = 1000 ):
233+ """
234+ Args:
235+ num_classes (uint): number of classes
236+ """
237+ model = GRCNN ([3 , 3 , 4 , 3 ], [64 , 128 , 256 , 512 ], SKconv = False , expansion = 4 , num_classes = num_classes )
238+ return model
239+
240+
241+
242+ #dir_path = os.path.dirname(os.path.realpath(""))
243+
244+ weights_path = load_weight_file (bucket = "brainscore-vision" , folder_name = "models" ,
245+ relative_path = "grcnn/checkpoint_params_grcnn55.pt" ,
246+ version_id = "SnkgwO32ntpKnS9UzOz8RecLiaDK6iYn" ,
247+ sha1 = "20fb844e72f21aeb257c053adb2b645bc954839e" )
248+ checkpoint = torch .load (weights_path , map_location = device )
249+ model_ft = grcnn55 () #models.resnet50(pretrained=True)
250+ model_ft .load_state_dict (checkpoint )
251+ model_ft = model_ft .to (device )
252+
253+
254+
255+ # get_model method actually gets the model. For a custom model, this is just linked to the
256+ # model we defined above.
257+ def get_model (name ):
258+ """
259+ This method fetches an instance of a base model. The instance has to be callable and return a xarray object,
260+ containing activations. There exist standard wrapper implementations for common libraries, like pytorch and
261+ keras. Checkout the examples folder, to see more. For custom implementations check out the implementation of the
262+ wrappers.
263+ :param name: the name of the model to fetch
264+ :return: the model instance
265+ """
266+ assert name == 'grcnn'
267+ # link the custom model to the wrapper object(activations_model above):
268+ preprocessing = functools .partial (load_preprocess_images , image_size = 224 )
269+ wrapper = PytorchWrapper (identifier = 'grcnn' , model = model_ft , preprocessing = preprocessing )
270+ wrapper .image_size = 224
271+ return wrapper
272+
273+
274+ # get_layers method to tell the code what layers to consider. If you are submitting a custom
275+ # model, then you will most likley need to change this method's return values.
276+ def get_layers (name ):
277+ """
278+ This method returns a list of string layer names to consider per model. The benchmarks maps brain regions to
279+ layers and uses this list as a set of possible layers. The lists doesn't have to contain all layers, the less the
280+ faster the benchmark process works. Additionally the given layers have to produce an activations vector of at least
281+ size 25! The layer names are delivered back to the model instance and have to be resolved in there. For a pytorch
282+ model, the layer name are for instance dot concatenated per module, e.g. "features.2".
283+ :param name: the name of the model, to return the layers for
284+ :return: a list of strings containing all layers, that should be considered as brain area.
285+ """
286+
287+ # quick check to make sure the model is the correct one:
288+ assert name == 'grcnn'
289+
290+ # returns the layers you want to consider
291+ return [layer for layer , _ in model_ft .named_modules ()][1 :]
292+
293+ # Bibtex Method. For submitting a custom model, you can either put your own Bibtex if your
294+ # model has been published, or leave the empty return value if there is no publication to refer to.
295+ def get_bibtex (model_identifier ):
296+ """
297+ A method returning the bibtex reference of the requested model as a string.
298+ """
299+
300+ # from pytorch.py:
301+ return '''@misc{cheng2020grcnngraphrecognitionconvolutional,
302+ title={GRCNN: Graph Recognition Convolutional Neural Network for Synthesizing Programs from Flow Charts},
303+ author={Lin Cheng and Zijiang Yang},
304+ year={2020},
305+ eprint={2011.05980},
306+ archivePrefix={arXiv},
307+ primaryClass={cs.CV},
308+ url={https://arxiv.org/abs/2011.05980},
309+ }'''
310+
311+ # Main Method: In submitting a custom model, you should not have to mess with this.
312+ if __name__ == '__main__' :
313+ # Use this method to ensure the correctness of the BaseModel implementations.
314+ # It executes a mock run of brain-score benchmarks.
315+ check_models .check_base_models (__name__ )
0 commit comments