Skip to content

Commit

Permalink
Changed utils.py; Added window_utils.py and merge.py (#21)
Browse files Browse the repository at this point in the history
* changes in utils folder; added window_utils.py and merge.py

* Update vformer/utils/window_utils.py

* Update vformer/functional/merge.py

Co-authored-by: Neelay Shah <shahnh19@gmail.com>
  • Loading branch information
abhi-glitchhg and NeelayS authored Oct 11, 2021
1 parent 4f8b91a commit 00eec73
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 164 deletions.
2 changes: 1 addition & 1 deletion tests/test_encoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from vformer.encoder import SwinEncoder, SwinEncoderBlock, VanillaEncoder
from vformer.utils import PatchMerging
from vformer.functional import PatchMerging

test_tensor1 = torch.randn(2, 65, 1024)
test_tensor2 = torch.randn(3, 3136, 96)
Expand Down
1 change: 0 additions & 1 deletion vformer/encoder/embedding/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

Expand Down
8 changes: 1 addition & 7 deletions vformer/encoder/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
from timm.models.layers import DropPath

from ..attention.window import WindowAttention
from ..utils.utils import (
create_mask,
cyclicshift,
pair,
window_partition,
window_reverse,
)
from ..utils import create_mask, cyclicshift, pair, window_partition, window_reverse
from .nn import FeedForward


Expand Down
3 changes: 2 additions & 1 deletion vformer/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .norm import *
from .merge import PatchMerging
from .norm import PreNorm
41 changes: 41 additions & 0 deletions vformer/functional/merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn

from ..utils import pair


class PatchMerging(nn.Module):
    """
    Swin Transformer downsampling layer: merges every 2x2 group of
    neighbouring patches, halving the spatial resolution and projecting
    the concatenated channels (4*C) down to 2*C.

    Parameters :
    ------------
    input_resolution: int or tuple[int]
        Resolution of input features
    dim : int
        Number of input channels
    norm_layer: nn.Module, optional
        Normalisation applied before the linear reduction (default LayerNorm)
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()  # modern zero-arg super() instead of super(PatchMerging, self)
        self.input_resolution = pair(input_resolution)
        self.dim = dim
        # 2x2 concatenation yields 4*dim channels; project down to 2*dim.
        self.reduction = nn.Linear(4 * self.dim, 2 * self.dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        Parameters:
        -----------
        x: torch.Tensor
            Input of shape (B, H*W, C)

        Returns:
        --------
        torch.Tensor
            Output of shape (B, H/2 * W/2, 2*C)
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # The four corners of every 2x2 neighbourhood.
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
2 changes: 1 addition & 1 deletion vformer/models/classification/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ...common import BaseClassificationModel
from ...decoder import MLPDecoder
from ...encoder import PatchEmbedding, SwinEncoder
from ...utils import PatchMerging
from ...functional import PatchMerging


class SwinTransformer(BaseClassificationModel):
Expand Down
3 changes: 2 additions & 1 deletion vformer/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .utils import *
from .utils import pair
from .window_utils import *
152 changes: 0 additions & 152 deletions vformer/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,159 +1,7 @@
import math
import warnings

import torch
import torch.nn as nn


def pair(t):
    """
    Normalise *t* to a 2-tuple: tuples pass through untouched, any other
    value is duplicated.

    Parameters:
    -----------
    t: tuple[int] or int
    """
    if isinstance(t, tuple):
        return t
    return (t, t)


def cyclicshift(input, shift_size, dims=None):
    """
    Roll the tensor cyclically by ``shift_size`` places along ``dims``.

    Parameters:
    ----------
    input: torch.Tensor
        input tensor
    shift_size: int or tuple[int]
        Number of places by which input tensor is shifted
    dims: int or tuple[int],optional
        Axis along which to roll; defaults to the spatial axes (1, 2)
    """
    # `is None` (identity) is the correct idiom; `== None` invokes __eq__,
    # which misbehaves for tensor-like arguments.
    return torch.roll(
        input, shifts=pair(shift_size), dims=(1, 2) if dims is None else dims
    )


class PatchMerging(nn.Module):
    """
    Swin-style downsampling: concatenates each 2x2 group of neighbouring
    patches and linearly projects the 4*C result down to 2*C channels.

    Parameters :
    ------------
    input_resolution: int or tuple[int]
        Resolution of input features
    dim : int
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super(PatchMerging, self).__init__()
        self.input_resolution = pair(input_resolution)
        self.dim = dim
        self.reduction = nn.Linear(4 * self.dim, 2 * self.dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        H, W = self.input_resolution
        batch, n_tokens, channels = x.shape
        assert n_tokens == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(batch, H, W, channels)

        # One slice per corner of every 2x2 neighbourhood, concatenated
        # channel-wise: B x H/2 x W/2 x 4*C.
        corners = [
            grid[:, row::2, col::2, :]
            for row, col in ((0, 0), (1, 0), (0, 1), (1, 1))
        ]
        merged = torch.cat(corners, dim=-1)
        merged = merged.view(batch, -1, 4 * channels)  # B H/2*W/2 4*C

        merged = self.norm(merged)
        return self.reduction(merged)


def window_partition(x, window_size):
    """
    Split a batched feature map into non-overlapping square windows.

    Parameters:
    -----------
    x: torch.Tensor
        input tensor of shape (B, H, W, C)
    window_size: int
        window size

    Returns:
    --------
    torch.Tensor
        Windows of shape (num_windows * B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    n_h = H // window_size
    n_w = W // window_size
    tiled = x.view(B, n_h, window_size, n_w, window_size, C)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    """
    Inverse of ``window_partition``: stitch windows back into a feature map.

    Parameters:
    -----------
    windows: torch.Tensor
        Windows of shape (num_windows * B, window_size, window_size, C)
    window_size: int
        Window size
    H: int
        Height of the original feature map
    W: int
        Width of the original feature map
    """
    # Integer arithmetic throughout: the original int(a / b / c) relied on
    # float division, which can lose precision for large H*W.
    B = windows.shape[0] // (H * W // (window_size * window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def get_relative_position_bias_index(window_size):
    """
    Build the (Wh*Ww, Wh*Ww) table of indices into the relative position
    bias parameter used by Swin window attention.

    Parameters:
    ------------
    window_size: int or tuple[int]
        Window size
    """
    window_size = pair(window_size)
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # Pairwise coordinate differences between every two positions in the window.
    relative_coords = (
        coords_flatten[:, :, None] - coords_flatten[:, None, :]
    )  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    # Scale the row offset so each (row, col) offset pair maps to a unique index.
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)
    return relative_position_index


def create_mask(window_size, shift_size, H, W):
    """
    Build the additive attention mask for shifted-window self-attention:
    window positions that originate from different (pre-shift) regions get
    -100 so that softmax assigns them ~zero weight; same-region pairs get 0.

    Parameters:
    -----------
    window_size: int
        Window Size
    shift_size: int
        Shift_size
    H: int
        Height of the feature map
    W: int
        Width of the feature map
    """
    img_mask = torch.zeros(1, H, W, 1)
    # Three bands per axis: untouched area, wrap-around area, shifted area.
    h_slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    w_slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    cnt = 0
    # Give each of the nine (h-band, w-band) regions a distinct id.
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.view(-1, window_size * window_size)
    # Non-zero difference => positions came from different regions.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )
    return attn_mask
111 changes: 111 additions & 0 deletions vformer/utils/window_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import torch

from .utils import pair


def cyclicshift(input, shift_size, dims=None):
    """
    Roll the tensor cyclically by ``shift_size`` places along ``dims``.

    Parameters:
    ----------
    input: torch.Tensor
        input tensor
    shift_size: int or tuple[int]
        Number of places by which input tensor is shifted
    dims: int or tuple[int],optional
        Axis along which to roll; defaults to the spatial axes (1, 2)
    """
    # `is None` (identity) is the correct idiom; `== None` invokes __eq__,
    # which misbehaves for tensor-like arguments.
    return torch.roll(
        input, shifts=pair(shift_size), dims=(1, 2) if dims is None else dims
    )


def window_partition(x, window_size):
    """
    Split a batched feature map into non-overlapping square windows.

    Parameters:
    -----------
    x: torch.Tensor
        input tensor of shape (B, H, W, C)
    window_size: int
        window size

    Returns:
    --------
    torch.Tensor
        Windows of shape (num_windows * B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)


def window_reverse(windows, window_size, H, W):
    """
    Inverse of ``window_partition``: stitch windows back into a feature map.

    Parameters:
    -----------
    windows: torch.Tensor
        Windows of shape (num_windows * B, window_size, window_size, C)
    window_size: int
        Window size
    H: int
        Height of the original feature map
    W: int
        Width of the original feature map
    """
    # Integer arithmetic throughout: the original int(a / b / c) relied on
    # float division, which can lose precision for large H*W.
    B = windows.shape[0] // (H * W // (window_size * window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def get_relative_position_bias_index(window_size):
    """
    Build the (Wh*Ww, Wh*Ww) table of indices into the relative position
    bias parameter used by Swin window attention.

    Parameters:
    ------------
    window_size: int or tuple[int]
        Window size
    """
    window_size = pair(window_size)
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # Pairwise coordinate differences between every two positions in the window.
    relative_coords = (
        coords_flatten[:, :, None] - coords_flatten[:, None, :]
    )  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    # Scale the row offset so each (row, col) offset pair maps to a unique index.
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)
    return relative_position_index


def create_mask(window_size, shift_size, H, W):
    """
    Build the additive attention mask for shifted-window self-attention:
    window positions coming from different (pre-shift) regions receive -100
    so softmax gives them ~zero weight; same-region pairs receive 0.

    Parameters:
    -----------
    window_size: int
        Window Size
    shift_size: int
        Shift_size
    """
    region_map = torch.zeros(1, H, W, 1)
    # Three bands per axis: untouched area, wrap-around area, shifted area.
    bands = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    # Label each of the nine (h-band, w-band) regions with a unique id.
    region_id = 0
    for h_band in bands:
        for w_band in bands:
            region_map[:, h_band, w_band, :] = region_id
            region_id += 1
    windows = window_partition(region_map, window_size)
    windows = windows.view(-1, window_size * window_size)
    # Non-zero difference => positions came from different regions.
    diff = windows.unsqueeze(1) - windows.unsqueeze(2)
    mask = diff.masked_fill(diff != 0, float(-100.0))
    mask = mask.masked_fill(diff == 0, float(0.0))
    return mask

0 comments on commit 00eec73

Please sign in to comment.