'Aggregating Nested Transformers'
    - https://arxiv.org/abs/2105.12723

-The official Jax code is released and available at https://github.com/google-research/nested-transformer
+The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights
+have been converted with convert/convert_nest_flax.py

Acknowledgments:
* The paper authors for sharing their research, code, and model weights
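For reference, once this module is importable the converted checkpoints can be loaded through the usual timm factory. A minimal sketch (assumes the converted weights referenced above are reachable at the configured checkpoint URL):

import timm
import torch

# jx_nest_base is registered further down in this file; pretrained=True pulls the weights
# converted from the official Jax release (assumption: the checkpoint URL is live).
model = timm.create_model('jx_nest_base', pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.zeros(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) with the default ImageNet-1k head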
_logger = logging.getLogger(__name__)


-# TODO check first_conv. everything else has been checked
def _cfg(url='', **kwargs):
    return {
        'url': url,
@@ -60,7 +60,6 @@ def _cfg(url='', **kwargs):
    }


-# TODO - Leave note for Ross - Maybe we can generalize Attention to this and put it in layers
class Attention(nn.Module):
    """
    This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
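Conceptually, the localised attention means tokens only attend to other tokens in the same block. One way to picture the equivalence is to fold the block dimension into the batch and run ordinary multi-head attention per block; a rough sketch with a hypothetical helper (not the Attention class below, which keeps the extra dimension inside its own projections):

# Sketch only: fold blocks into the batch, attend within each block, then unfold.
# x: (B, T, N, C) with T = number of blocks, N = tokens per block.
# mha is assumed to be nn.MultiheadAttention(C, num_heads, batch_first=True).
def blocked_attention_sketch(x, mha):
    B, T, N, C = x.shape
    x = x.reshape(B * T, N, C)               # each block becomes an independent sequence
    x, _ = mha(x, x, x, need_weights=False)  # tokens only attend within their own block
    return x.reshape(B, T, N, C)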
@@ -102,7 +101,6 @@ class TransformerLayer(Block):
    This is much like `.vision_transformer.Block` but:
        - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
        - Uses modified Attention layer that handles the "block" dimension
-        - TODO somehow reuse the code instead of rewriting it...
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
@@ -172,6 +170,39 @@ def deblockify(x, block_size: int):
    height = width = grid_size * block_size
    x = x.reshape(B, height, width, C)
    return x  # (B, H, W, C)
+
+
+class NestLevel(nn.Module):
+    """ Single hierarchical level of a Nested Transformer
+    """
+    def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None):
+        super().__init__()
+        self.block_size = block_size
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
+        # Transformer encoder
+        if len(drop_path_rates):
+            assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
+        self.transformer_encoder = nn.Sequential(*[
+            TransformerLayer(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i],
+                norm_layer=norm_layer, act_layer=act_layer)
+            for i in range(depth)])
+
+    def forward(self, x):
+        """
+        expects x as (B, C, H, W)
+        """
+        # Switch to channels last for transformer
+        x = x.permute(0, 2, 3, 1)  # (B, H', W', C)
+        x = blockify(x, self.block_size)  # (B, T, N, C')
+        x = x + self.pos_embed
+        x = self.transformer_encoder(x)  # (B, T, N, C')
+        x = deblockify(x, self.block_size)  # (B, H', W', C')
+        # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
+        x = x.permute(0, 3, 1, 2)  # (B, C, H', W')
+        return x


class Nest(nn.Module):
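For reference, blockify and deblockify are exact inverses for compatible sizes. A minimal sanity check using the first-level geometry of the default configuration (56x56 map from a 224 input with patch_size 4, block_size 14); assumes the blockify(x, block_size) signature used by NestLevel above:

x = torch.randn(2, 56, 56, 128)               # (B, H, W, C) after patch embedding
blocks = blockify(x, block_size=14)           # (B, T, N, C): T = (56 // 14) ** 2 = 16, N = 14 * 14 = 196
assert blocks.shape == (2, 16, 196, 128)
assert torch.equal(deblockify(blocks, block_size=14), x)  # deblockify undoes blockify exactly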
@@ -182,10 +213,9 @@ class Nest(nn.Module):
    """

    def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
-                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4.,
-                 qkv_bias=True, pad_type='',
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None,
-                 act_layer=None, weight_init='', global_pool='avg'):
+                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='',
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='',
+                 global_pool='avg'):
        """
        Args:
            img_size (int, tuple): input image size
@@ -203,7 +233,7 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer for transformer layers
            act_layer: (nn.Module): activation layer in MLP of transformer layers
-            weight_init: (str): weight init scheme TODO check
+            weight_init: (str): weight init scheme
            global_pool: (str): type of pooling operation to apply to final feature map

        Notes:
@@ -247,45 +277,33 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
+        self.feature_info = [dict(num_chs=embed_dims[0], reduction=patch_size, module='patch_embed')]
        self.num_patches = self.patch_embed.num_patches
+        self.seq_length = self.num_patches // self.num_blocks[0]

        # Build up each hierarchical level
-        self.ls_pos_embed = []
-        self.ls_transformer_encoder = nn.ModuleList([])
-        self.ls_block_aggregation = nn.ModuleList([])
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # drop path rate
-        self.feature_info = []
-        for level in range(self.num_levels):
-            # Positional embedding
-            # NOTE: Can't use ParameterList for positional embedding as it can't be enumerated with TorchScript
-            pos_embed = nn.Parameter(
-                torch.zeros(1, self.num_blocks[level], self.num_patches // self.num_blocks[0], embed_dims[level]))
-            self.register_parameter(f'pos_embed_{level}', pos_embed)
-            self.ls_pos_embed.append(pos_embed)
-            # Transformer encoder
-            self.ls_transformer_encoder.append(nn.Sequential(*[
-                TransformerLayer(
-                    dim=embed_dims[level], num_heads=num_heads[level], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
-                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:level]) + i],
-                    norm_layer=norm_layer, act_layer=act_layer)
-                for i in range(depths[level])]))
-
-            self.feature_info.append(dict(
-                num_chs=embed_dims[level], reduction=2,
-                module=f'ls_transformer_encoder.{level}.{depths[level]-1}.mlp.fc2'))
-
-            # Block aggregation (not required for last level)
-            if level < self.num_levels - 1:
-                self.ls_block_aggregation.append(
-                    BlockAggregation(embed_dims[level], embed_dims[level + 1], norm_layer, pad_type=pad_type))
+        self.levels = nn.ModuleList([])
+        self.block_aggs = nn.ModuleList([])
+        drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        for lix in range(self.num_levels):
+            dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix + 1])]
+            self.levels.append(NestLevel(
+                self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
+                embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
+                act_layer))
+            self.feature_info.append(
+                dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'] * 2, module=f'levels.{lix}'))
+            if lix < self.num_levels - 1:
+                self.block_aggs.append(BlockAggregation(
+                    embed_dims[lix], embed_dims[lix + 1], norm_layer, pad_type=pad_type))
            else:
-                # NOTE: Required for enumeration over all level components at once
-                self.ls_block_aggregation.append(nn.Identity())
-        self.ls_pos_embed = tuple(self.ls_pos_embed)  # static length required for torchscript
+                # Required for zipped iteration over levels and block_aggs together
+                self.block_aggs.append(nn.Identity())

-
        # Final normalization layer
        self.norm = norm_layer(embed_dims[-1])
+        self.feature_info.append(
+            dict(num_chs=embed_dims[lix], reduction=self.feature_info[-1]['reduction'], module='norm'))

        # Classifier
        self.global_pool, self.head = create_classifier(
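To make the stochastic depth schedule concrete, the per-level split computed above for the default depths=(2, 2, 20) and drop_path_rate=0.5 looks like this (a standalone sketch of the same torch.linspace slicing):

depths, drop_path_rate = (2, 2, 20), 0.5
rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 24 evenly spaced values, 0 -> 0.5
per_level = [rates[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
# per_level[0] and per_level[1] each get 2 small rates; per_level[2] gets the remaining 20, rising to 0.5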
@@ -296,8 +314,8 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-        for pos_embed in self.ls_pos_embed:
-            trunc_normal_(pos_embed, std=.02, a=-2, b=2)
+        for level in self.levels:
+            trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
        if mode.startswith('jax'):
            named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
        else:
@@ -319,22 +337,13 @@ def forward_features(self, x):
        """ x shape (B, C, H, W)
        """
        B, _, H, W = x.shape
-        x = self.patch_embed(x)  # (B, N, C)
+        x = self.patch_embed(x)
        x = x.reshape(B, H // self.patch_size, W // self.patch_size, -1)  # (B, H', W', C')
-        # NOTE: TorchScript wants enumeration rather than subscripting of ModuleList
-        for level, (pos_embed, transformer, block_agg) in enumerate(
-                zip(self.ls_pos_embed, self.ls_transformer_encoder, self.ls_block_aggregation)):
-            if level > 0:
-                # Switch back to channels last for transformer
-                x = x.permute(0, 2, 3, 1)  # (B, H', W', C)
-            x = blockify(x, self.block_size)  # (B, T, N, C')
-            x = x + pos_embed
-            x = transformer(x)  # (B, T, N, C')
-            x = deblockify(x, self.block_size)  # (B, H', W', C')
-            # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
-            x = x.permute(0, 3, 1, 2)  # (B, C, H', W')
-            if level < self.num_levels - 1:
-                x = block_agg(x)  # (B, C', H'//2, W'//2)
+        x = x.permute(0, 3, 1, 2)
+        # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
+        for level, block_agg in zip(self.levels, self.block_aggs):
+            x = level(x)
+            x = block_agg(x)
        # Layer norm done over channel dim only
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
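For reference, a worked shape trace of forward_features with the default 3-level configuration (patch_size=4, embed_dims=(128, 256, 512)) on a 224x224 input; each block aggregation halves the spatial grid:

# (B, 3, 224, 224) -> patch_embed + reshape + permute -> (B, 128, 56, 56)
# levels[0] -> (B, 128, 56, 56) -> block_aggs[0] -> (B, 256, 28, 28)
# levels[1] -> (B, 256, 28, 28) -> block_aggs[1] -> (B, 512, 14, 14)
# levels[2] -> (B, 512, 14, 14) -> Identity -> norm -> (B, 512, 14, 14)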
@@ -404,11 +413,12 @@ def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
    #   raise RuntimeError('features_only not implemented for Vision Transformer models.')

    default_cfg = default_cfg or default_cfgs[variant]
-
    model = build_model_with_cfg(
        Nest, variant, pretrained,
        default_cfg=default_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(
+            out_indices=tuple(range(kwargs.get('num_levels', 3) + 2)), feature_cls='hook', flatten_sequential=True),
        **kwargs)

    return model
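With the hook-based feature_cfg above, registered variants can also be built as pure feature extractors. A minimal sketch (assumes the hook extraction works as configured; channel counts and strides follow the feature_info entries built in __init__):

import timm
import torch

fx = timm.create_model('jx_nest_base', features_only=True)
feats = fx(torch.zeros(1, 3, 224, 224))
print([f.shape for f in feats])  # one map per out_index: patch_embed, each of the three levels, final norm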
@@ -478,3 +488,10 @@ def jx_nest_tiny(pretrained=False, **kwargs):
        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
    model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
    return model
+
+
+if __name__ == '__main__':
+    model = jx_nest_base()
+    model = torch.jit.script(model)
+    inp = torch.zeros(8, 3, 224, 224)
+    print(model.forward_features(inp).shape)