-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Open
Description
本质上,只要保证移位后每个新窗口内的各个 patch 能按来源区域相互区分即可。作者通过 mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 得到窗口内两两 patch 的来源差异图,那么把整张特征图划分成 4 个区域(而不是 9 个)就足够了。
import torch
def window_partition(x, window_size):
    """Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C).
        window_size (int): side length of each window. H and W are expected
            to be multiples of it; the reshape below relies on that.

    Returns:
        windows: tensor of shape (num_windows*B, window_size, window_size, C),
        ordered by batch, then window row, then window column.
    """
    B, H, W, C = x.shape
    # Split H and W each into (window index, offset inside the window).
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Move the two window-index axes next to the batch axis so each window's
    # pixels are contiguous, then fold (B, window row, window col) into one axis.
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
# Build the shifted-window attention mask from a 9-region (3 x 3) labelling.
window_size = 7
H, W = 56, 56
shift_size = window_size // 2
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
# Each axis is cut into three bands: everything before the last window,
# the un-shifted part of the last window, and the final shift_size rows/cols.
h_slices = (
    slice(0, -window_size),
    slice(-window_size, -shift_size),
    slice(-shift_size, None),
)
w_slices = (
    slice(0, -window_size),
    slice(-window_size, -shift_size),
    slice(-shift_size, None),
)
# Give every (row band, column band) pair its own integer region id, 0..8.
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
# Flatten each window so every row lists the region ids of its positions.
mask_windows = mask_windows.view(-1, window_size * window_size)
# Pairwise id difference inside each window: zero means "same region".
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# Pairs from different regions get -100.0; same-region pairs get 0.0.
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
#### Above: label the feature map with 9 regions (3 x 3 bands)
####################################################################################################
#### Below: label the feature map with 4 regions (2 x 2 bands)
# Build the same kind of attention mask from a coarser 4-region (2 x 2)
# labelling, then compare it element-wise with the 9-region attn_mask.
img_mask1 = torch.zeros((1, H, W, 1))  # 1 H W 1
# Each axis is cut into only two bands: everything before the final
# shift_size rows/cols, and the final shift_size rows/cols themselves.
h_slices1 = (
    slice(0, -shift_size),
    slice(-shift_size, None),
)
w_slices1 = (
    slice(0, -shift_size),
    slice(-shift_size, None),
)
# Give every (row band, column band) pair its own integer region id, 0..3.
cnt = 0
for h in h_slices1:
    for w in w_slices1:
        img_mask1[:, h, w, :] = cnt
        cnt += 1
mask_windows1 = window_partition(img_mask1, window_size)  # nW, window_size, window_size, 1
mask_windows1 = mask_windows1.view(-1, window_size * window_size)
# Only id differences between positions of the same window enter the mask.
attn_mask1 = mask_windows1.unsqueeze(1) - mask_windows1.unsqueeze(2)
attn_mask1 = attn_mask1.masked_fill(attn_mask1 != 0, -100.0).masked_fill(attn_mask1 == 0, 0.0)
# Element-wise agreement between the 9-region and the 4-region masks.
t = attn_mask == attn_mask1
# Prints tensor(True) exactly when every entry of the two masks matches.
print(t.all())
结果是 tensor(True),即两种划分得到的 attn_mask 完全一致。这是否说明直接划分成四个区域就可以了呢?
oneengineer and fuyw
Metadata
Metadata
Assignees
Labels
No labels