Commit 0d124ff

Update README. Fine-grained layer-wise lr decay working for tiny_vit and both efficientvits. Minor fixes.
1 parent 2f0fbb5 commit 0d124ff

File tree

README.md
timm/models/efficientvit_mit.py
timm/models/efficientvit_msra.py
timm/models/tiny_vit.py

4 files changed: +30 −12 lines changed

README.md

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license.
 * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.
 
+### Sep 1, 2023
+* TinyViT added by [SeeFun](https://github.com/seefun)
+* Fix EfficientViT (MIT) to use torch.autocast so it works back to PT 1.10
+
 ### Aug 28, 2023
 * Add dynamic img size support to models in `vision_transformer.py`, `vision_transformer_hybrid.py`, `deit.py`, and `eva.py` w/o breaking backward compat.
 * Add `dynamic_img_size=True` to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass).
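
For reference, a minimal sketch of the PT 1.10+ `torch.autocast` API that the EfficientViT (MIT) bullet refers to. This is illustrative only (it assumes a CUDA device is available) and is not the code changed in this commit:

import torch

model = torch.nn.Linear(16, 4).cuda()
x = torch.randn(8, 16, device='cuda')

# torch.autocast is the device-generic autocast context manager available since PyTorch 1.10.
with torch.autocast(device_type='cuda', dtype=torch.float16):
    y_fp16 = model(x)

    # Nesting with enabled=False is a common way to keep a numerically
    # sensitive sub-block in fp32 while the rest of the forward runs in fp16.
    with torch.autocast(device_type='cuda', enabled=False):
        y_fp32 = model(x.float())

print(y_fp16.dtype, y_fp32.dtype)  # torch.float16 torch.float32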

timm/models/efficientvit_mit.py

Lines changed: 8 additions & 8 deletions
@@ -508,11 +508,10 @@ def __init__(
 
         # stages
         self.feature_info = []
-        stages = []
-        stage_idx = 0
+        self.stages = nn.Sequential()
         in_channels = widths[0]
         for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
-            stages.append(EfficientVitStage(
+            self.stages.append(EfficientVitStage(
                 in_channels,
                 w,
                 depth=d,
@@ -524,10 +523,8 @@ def __init__(
             ))
             stride *= 2
             in_channels = w
-            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
-            stage_idx += 1
+            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
 
-        self.stages = nn.Sequential(*stages)
         self.num_features = in_channels
         self.head_widths = head_widths
         self.head_dropout = drop_rate
@@ -548,8 +545,11 @@ def __init__(
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None)]
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
 
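
The finer `group_matcher` above is what enables the fine-grained layer-wise lr decay mentioned in the commit message. A rough sketch of how it could be exercised through timm's optimizer factory, assuming a recent timm with `create_optimizer_v2` and its `layer_decay` argument, and using the `efficientvit_b0` config name purely as an example:

import timm
from timm.optim import create_optimizer_v2

# Any of the models touched by this commit should work here; efficientvit_b0
# is used only as an example config name.
model = timm.create_model('efficientvit_b0', pretrained=False, num_classes=10)

# layer_decay buckets parameters using model.group_matcher(coarse=False), so the
# per-stage / per-block regexes above translate into per-depth lr multipliers
# (earlier groups get smaller learning rates than the head).
optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
    layer_decay=0.75,
)

print(len(optimizer.param_groups), 'param groups')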

timm/models/efficientvit_msra.py

Lines changed: 9 additions & 2 deletions
@@ -441,11 +441,18 @@ def __init__(
         self.head = NormLinear(
             self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',
-            blocks=[(r'^stages\.(\d+)', None)]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
 
@@ -455,7 +462,7 @@ def set_grad_checkpointing(self, enable=True):
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.linear
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
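
The new `no_weight_decay()` above returns fully qualified parameter names (every `attention_biases` tensor), which makes it easy to exclude those tensors from weight decay even outside timm's own optimizer factory. A hedged sketch using plain PyTorch, with the `efficientvit_m0` config name as an assumed example from this model family:

import timm
import torch

# Assumed example config from the MSRA EfficientViT family in timm.
model = timm.create_model('efficientvit_m0', pretrained=False)

# Fully qualified names of the attention bias parameters, as returned by the
# no_weight_decay() method added in this commit.
skip = model.no_weight_decay()

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    (no_decay if name in skip else decay).append(param)

optimizer = torch.optim.AdamW(
    [
        {'params': decay, 'weight_decay': 0.05},
        {'params': no_decay, 'weight_decay': 0.0},
    ],
    lr=1e-3,
)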

timm/models/tiny_vit.py

Lines changed: 9 additions & 2 deletions
@@ -509,11 +509,18 @@ def _init_weights(self, m):
     def no_weight_decay_keywords(self):
         return {'attention_biases'}
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',
-            blocks=[(r'^stages\.(\d+)', None)]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+            ]
         )
         return matcher
 
@@ -523,7 +530,7 @@ def set_grad_checkpointing(self, enable=True):
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
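
With `get_classifier()` now returning the final projection inside the head (`head.fc` here, `head.linear` for the MSRA models) instead of the whole head module, code that expects an `nn.Linear`-like classifier keeps working. A small sketch of a linear-probe setup built on that, assuming the `tiny_vit_5m_224` config name as an example:

import timm

# Assumed example config name from the TinyViT family added in this commit.
model = timm.create_model('tiny_vit_5m_224', pretrained=False, num_classes=100)

# get_classifier() now returns the final projection layer inside the head
# (typically an nn.Linear when num_classes > 0) rather than the whole head module.
classifier = model.get_classifier()

# Simple linear-probe style setup: freeze the backbone, train only the classifier.
for p in model.parameters():
    p.requires_grad = False
for p in classifier.parameters():
    p.requires_grad = True

print([n for n, p in model.named_parameters() if p.requires_grad])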
