It seems that VisionTransformer doesn't support extracting the full set of token outputs from the forward_features method. Only the cls token, or the [cls_token, distillation_token] pair, can be returned (timm/models/vision_transformer.py#L291-L304). This functionality seems particularly useful, similar to how pretrained ResNet features are commonly used for downstream tasks.
```python
def forward_features(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]
```
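So even with the classifier head removed, the model still yields only the pooled cls-token representation. A quick check (shapes assume vit_base_patch16_224 with a 224x224 input; the snippet is just illustrative):

```python
import torch
import timm

m = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
out = m(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 768]) -- only the cls token; patch tokens are discarded
```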
However, this is available for other models, e.g.

```python
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
```
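With those arguments the ResNet returns its unpooled spatial feature map directly; for example, with a 224x224 input:

```python
import torch
import timm

m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
out = m(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 2048, 7, 7]) -- unpooled features, no pooling or head
```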
I implemented something similar for a side project here that required all ViT outputs for a downstream segmentation task. I simply override the method in my example, but I assume an attribute could be added to VisionTransformer to allow returning the 'unpooled' output. Maybe something like this:
```python
def forward_features(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    if self.unpooled:
        # return every patch token, dropping the cls (and distillation) token
        if self.dist_token is None:
            return x[:, 1:]
        else:
            return x[:, 2:]
    else:
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

def forward(self, x):
    x = self.forward_features(x)
    if not self.unpooled:
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
    return x
```
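For the segmentation use case, the unpooled tokens can then be reshaped back into a 2D feature map. A minimal sketch of the intended usage, assuming the `unpooled` flag were wired through `create_model` (hypothetical; that kwarg does not exist in timm today):

```python
import torch
import timm

# Hypothetical: assumes an `unpooled` kwarg is plumbed through to VisionTransformer.
m = timm.create_model('vit_base_patch16_224', pretrained=True, unpooled=True)
feats = m(torch.randn(1, 3, 224, 224))  # (1, 196, 768): one token per 16x16 patch

# Reshape the token sequence into a spatial map for dense prediction.
B, N, C = feats.shape
H = W = int(N ** 0.5)  # 14 x 14 patch grid for a 224 input
fmap = feats.transpose(1, 2).reshape(B, C, H, W)  # (1, 768, 14, 14)
```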