From 3055411c1bde27c18ac8d654ce85c81731622754 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 20 Aug 2023 21:22:59 -0700 Subject: [PATCH] Fix samvit bug, add F.sdpa support and ROPE option (#1920) * Fix a bug I introduced in samvit, add F.sdpa support and ROPE option to samvit, neck is LayerNorm if not used and standard classifier used * Add attn dropout to F.sdpa * Fix fx trace for sam vit * Fixing torchscript issues in samvit * Another torchscript fix * samvit head fc name fix --- timm/models/vision_transformer_sam.py | 272 ++++++++++++++++---------- 1 file changed, 174 insertions(+), 98 deletions(-) diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 9beb7b0162..53c49b071e 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -17,13 +17,15 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint +from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\ - Format, resample_abs_pos_embed_nhwc + Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model +from ._features_fx import register_notrace_function # model_registry will add each entrypoint fn to this __all__ = ['VisionTransformerSAM'] @@ -32,7 +34,77 @@ _logger = logging.getLogger(__name__) +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + +register_notrace_function(get_rel_pos) + + +def get_decomposed_rel_pos_bias( + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + Args: + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + bias (Tensor): attention bias to add to attention map + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn_bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + return attn_bias.reshape(-1, q_h * q_w, k_h * k_w) + + class Attention(nn.Module): + fused_attn: Final[bool] def __init__( self, @@ -44,14 +116,15 @@ def __init__( proj_drop=0., norm_layer=nn.LayerNorm, use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, + rope: Optional[nn.Module] = None, ): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() @@ -61,6 +134,7 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) self.use_rel_pos = use_rel_pos if self.use_rel_pos: + assert rope is None assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." @@ -69,26 +143,45 @@ def __init__( 2 * input_size[0] - 1, self.head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros( 2 * input_size[1] - 1, self.head_dim)) + self.rope = rope def forward(self, x): B, H, W, _ = x.shape - qkv = self.qkv(x).reshape( - B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + N = H * W + x = x.reshape(B, N, -1) + qkv = self.qkv(x).view(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # qkv with shape (3, B, nHead, H * W, C) - q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0) # q, k, v with shape (B * nHead, H * W, C) q, k = self.q_norm(q), self.k_norm(k) - q = q * self.scale - attn = q @ k.transpose(-2, -1) if self.use_rel_pos: - attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + attn_bias = get_decomposed_rel_pos_bias(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + else: + attn_bias = None + if self.rope is not None: + rope = self.rope.get_embed() + q = apply_rot_embed_cat(q, rope).type_as(v) + k = apply_rot_embed_cat(k, rope).type_as(v) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.view(B, self.num_heads, N, -1).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) - + x = x.view(B, H, W, -1) return x @@ -121,6 +214,7 @@ def __init__( use_rel_pos=False, window_size=0, input_size=None, + rope=None, ): super().__init__() self.window_size = window_size @@ -135,6 +229,7 @@ def __init__( norm_layer=norm_layer, use_rel_pos=use_rel_pos, input_size=input_size if window_size == 0 else (window_size, window_size), + rope=rope, ) 
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -150,20 +245,26 @@ def __init__( self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): + B, H, W, _ = x.shape + shortcut = x x = self.norm1(x) # Window partition + pad_hw: Optional[Tuple[int, int]] = None if self.window_size > 0: - H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.drop_path1(self.ls1(self.attn(x))) + # Reverse window partition if self.window_size > 0: - x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = window_unpartition(x, self.window_size, (H, W), pad_hw) x = shortcut + x + + x = x.reshape(B, H * W, -1) # MLP is faster for N, L, C tensor x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + x = x.reshape(B, H, W, -1) return x @@ -183,8 +284,7 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size - if pad_h > 0 or pad_w > 0: - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) @@ -193,7 +293,7 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T def window_unpartition( - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] + windows: torch.Tensor, window_size: int, hw: Tuple[int, int], pad_hw: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. @@ -206,90 +306,15 @@ def window_unpartition( Returns: x: unpartitioned sequences with [B, H, W, C]. """ - Hp, Wp = pad_hw + Hp, Wp = pad_hw if pad_hw is not None else hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) - - if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :].contiguous() return x -def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - Args: - q_size (int): size of query q. - k_size (int): size of key k. - rel_pos (Tensor): relative position embeddings (L, C). - - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. - rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), - size=max_rel_dist, - mode="linear", - ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. 
- q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return rel_pos_resized[relative_coords.long()] - - -def add_decomposed_rel_pos( - attn: torch.Tensor, - q: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], -) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py - Args: - attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). - - Returns: - attn (Tensor): attention map with added relative positional embeddings. - """ - q_h, q_w = q_size - k_h, k_w = k_size - Rh = get_rel_pos(q_h, k_h, rel_pos_h) - Rw = get_rel_pos(q_w, k_w, rel_pos_w) - - B, _, dim = q.shape - r_q = q.reshape(B, q_h, q_w, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) - - attn = ( - attn.view(B, q_h, q_w, k_h, k_w) + - rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] - ).view(B, q_h * q_w, k_h * k_w) - - return attn - - class VisionTransformerSAM(nn.Module): """ Vision Transformer for Segment-Anything Model(SAM) @@ -326,11 +351,13 @@ def __init__( mlp_layer: Callable = Mlp, use_abs_pos: bool = True, use_rel_pos: bool = False, + use_rope: bool = False, window_size: int = 14, global_attn_indexes: Tuple[int, ...] = (), neck_chans: int = 256, global_pool: str = 'avg', - head_hidden_size: Optional[int] = None + head_hidden_size: Optional[int] = None, + ref_feat_shape: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None ): """ Args: @@ -356,10 +383,12 @@ def __init__( block_fn: Transformer block layer. use_abs_pos: If True, use absolute positional embeddings. use_rel_pos: If True, add relative positional embeddings to the attention map. + use_rope: If True, add rotary position embeddings to q/k in attention block. window_size: Window size for window attention blocks. If 0, not use window attention. global_attn_indexes: Indexes for blocks using global attention. Used when window_size > 0. global_pool: Global pooling type. 
head_hidden_size: If set, use NormMlpHead + ref_feat_shape: Tuple of reference feature shapes for ROPE, (global, local) """ super().__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) @@ -394,6 +423,30 @@ def __init__( self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + if use_rope: + assert not use_rel_pos, "ROPE and relative pos embeddings should not be enabled at same time" + if ref_feat_shape is not None: + assert len(ref_feat_shape) == 2 + ref_feat_shape_global = to_2tuple(ref_feat_shape[0]) + ref_feat_shape_window = to_2tuple(ref_feat_shape[1]) + else: + ref_feat_shape_global = ref_feat_shape_window = None + self.rope_global = RotaryEmbeddingCat( + embed_dim // num_heads, + in_pixels=False, + feat_shape=grid_size, + ref_feat_shape=ref_feat_shape_global, + ) + self.rope_window = RotaryEmbeddingCat( + embed_dim // num_heads, + in_pixels=False, + feat_shape=to_2tuple(window_size), + ref_feat_shape=ref_feat_shape_window, + ) + else: + self.rope_global = None + self.rope_window = None + # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] self.blocks = nn.Sequential(*[ @@ -413,6 +466,7 @@ def __init__( use_rel_pos=use_rel_pos, window_size=window_size if i not in global_attn_indexes else 0, input_size=grid_size, + rope=self.rope_window if i not in global_attn_indexes else self.rope_global, ) for i in range(depth)]) @@ -436,7 +490,11 @@ def __init__( ) self.num_features = neck_chans else: - self.neck = nn.Identity() + if head_hidden_size: + self.neck = nn.Identity() + else: + # should have a final norm with standard ClassifierHead + self.neck = LayerNorm2d(embed_dim) neck_chans = embed_dim # Classifier Head @@ -526,7 +584,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 1024, 1024), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', **kwargs } @@ -552,6 +610,10 @@ def _cfg(url='', **kwargs): license='apache-2.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 1024, 1024), crop_pct=1.0), + + 'samvit_base_patch16_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=1000, + input_size=(3, 224, 224), crop_pct=0.9), }) @@ -606,3 +668,17 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM: model = _create_vision_transformer( 'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs)) return model + + +@register_model +def samvit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformerSAM: + """ ViT-B/16 based on samvit arch + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, global_attn_indexes=[2, 5, 8, 11], + window_size=14, use_rel_pos=True, use_abs_pos=False, img_size=224, neck_chans=None, + ) + model = _create_vision_transformer( + 'samvit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model +
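
Reviewer note (not part of the patch): a quick way to sanity-check the window_unpartition call-site fix (the hw and pad_hw arguments were previously passed in the wrong order) is a round trip through the module-level helpers changed above. This assumes this branch of timm is installed; the 13x13 feature map is an arbitrary size chosen so that padding is exercised.

import torch
from timm.models.vision_transformer_sam import window_partition, window_unpartition

x = torch.randn(2, 13, 13, 8)             # H = W = 13 is not a multiple of the window size
windows, pad_hw = window_partition(x, 7)   # pads to 14x14, then splits into 7x7 windows
y = window_unpartition(windows, 7, (13, 13), pad_hw)  # new argument order: hw, then optional pad_hw
print(torch.equal(x, y))                   # expected: True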
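
The new fused branch in Attention.forward hands the decomposed relative-position bias to torch.nn.functional.scaled_dot_product_attention via attn_mask instead of adding it to the attention map by hand. The standalone snippet below (not part of the patch; random tensors stand in for q, k, v and the bias) shows the two branches agree up to numerical tolerance.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, C = 2, 16, 32
q, k, v = torch.randn(3, B, N, C).unbind(0)
bias = torch.randn(B, N, N)  # stand-in for the rel-pos attention bias

# fused path, as used when self.fused_attn is True
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# manual path, as used otherwise
scale = C ** -0.5
attn = (q * scale) @ k.transpose(-2, -1) + bias
manual = attn.softmax(dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True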
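
Finally, a minimal smoke test for the new samvit_base_patch16_224 entrypoint and the use_rope option. It assumes this branch is installed and that extra keyword arguments are forwarded to the VisionTransformerSAM constructor as in the entrypoints above; the model name, use_rope and use_rel_pos come from the diff, while the batch size and num_classes are arbitrary.

import torch
import timm

x = torch.randn(1, 3, 224, 224)

# default config: decomposed relative position bias, LayerNorm2d neck, standard classifier head
model = timm.create_model('samvit_base_patch16_224', pretrained=False, num_classes=10)
print(model(x).shape)  # expected: torch.Size([1, 10])

# same config with rotary position embeddings; the constructor asserts that
# use_rope and use_rel_pos are mutually exclusive, so rel-pos must be turned off
model_rope = timm.create_model(
    'samvit_base_patch16_224', pretrained=False, num_classes=10,
    use_rel_pos=False, use_rope=True,
)
print(model_rope(x).shape)  # expected: torch.Size([1, 10])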