Skip to content

生成 mask 时其实不需要划分成 9 份，4 份就够了？（附验证代码） #194

@jmjkx

Description

@jmjkx

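本质上，只需保证移位后每个新窗口内的各个 patch 能区分来源即可。作者通过 `mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)` 得到一张“是否同源”的图（差值为 0 表示同源，否则填 -100），因此完全可以只划分成 4 份。原因是 H、W 都是 window_size 的整数倍，位置 `-window_size` 恰好是窗口边界。所以 9 份划分中的第一段 `slice(0, -window_size)` 和第二段 `slice(-window_size, -shift_size)` 永远不会落在同一个窗口里，把它们合并成一段不会改变任何窗口内的 mask。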

489d3d064a5c802c33e0e66c4a6ddde
验证代码如下：

import torch


def window_partition(x, window_size):
    """Cut a channels-last feature map into non-overlapping square windows.

    Windows are emitted batch by batch; inside one batch element they are
    ordered row of windows first, then column of windows.

    Args:
        x: tensor of shape (B, H, W, C). H and W must both be multiples of
            ``window_size`` (otherwise the first ``view`` fails), and ``x``
            must be compatible with ``Tensor.view`` (e.g. contiguous).
        window_size (int): side length of each square window.

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size

    # (B, rows, ws, cols, ws, C): window indices interleaved with in-window offsets.
    grid = x.view(batch, rows, window_size, cols, window_size, channels)

    # Bring the two window-index axes together so that, after the copy, every
    # window's ws x ws block of patches is stored contiguously.
    grouped = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grouped.view(-1, window_size, window_size, channels)



# Verification: does the shifted-window attention mask change if the shifted
# image is labelled with 4 regions instead of the 9 used by the original code?
#
# The mask only records, for each pair of patches inside a window, whether the
# two patches carry the same region label (0.0) or different labels (-100.0).
# Since H and W are multiples of window_size, index H - window_size (and
# W - window_size) is always a window boundary. So the first two of the three
# 1-D slices never share a window, and merging them cannot change any window's
# mask.


def build_region_mask(H, W, h_slices, w_slices):
    """Return a (1, H, W, 1) label map with one distinct id per slice pair.

    Every (h_slice, w_slice) pair is filled with its own integer id, assigned
    in row-major order over ``h_slices`` x ``w_slices`` starting at 0.
    """
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    return img_mask


def build_attn_mask(H, W, window_size, h_slices, w_slices):
    """Return the additive attention mask, shape (nW, ws*ws, ws*ws).

    Entry [k, i, j] is 0.0 when patches i and j of window k have the same
    region label and -100.0 otherwise. H and W must be multiples of
    ``window_size`` (required by ``window_partition``).
    """
    img_mask = build_region_mask(H, W, h_slices, w_slices)
    mask_windows = window_partition(img_mask, window_size)  # nW, ws, ws, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # Both fills use masks computed from `diff` itself, so the second fill is
    # not affected by the -100.0 written by the first.
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


def nine_region_slices(window_size, shift_size):
    """Return the 1-D slices of the original implementation (3 per axis -> 9 regions)."""
    return (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))


def four_region_slices(shift_size):
    """Return the proposed 1-D slices (2 per axis -> 4 regions)."""
    return (slice(0, -shift_size),
            slice(-shift_size, None))


def masks_agree(H, W, window_size, shift_size):
    """Return True when the 9-region and 4-region labellings give identical masks.

    ``shift_size`` must satisfy 0 < shift_size < window_size. With a shift of 0,
    ``slice(-0, None)`` would cover the whole axis.
    """
    nine = nine_region_slices(window_size, shift_size)
    four = four_region_slices(shift_size)
    return torch.equal(
        build_attn_mask(H, W, window_size, nine, nine),
        build_attn_mask(H, W, window_size, four, four),
    )


window_size = 7
H, W = 56, 56
shift_size = window_size // 2

# 9 regions, as in the original implementation.
h_slices = w_slices = nine_region_slices(window_size, shift_size)
attn_mask = build_attn_mask(H, W, window_size, h_slices, w_slices)

# 4 regions, as proposed.
h_slices1 = w_slices1 = four_region_slices(shift_size)
attn_mask1 = build_attn_mask(H, W, window_size, h_slices1, w_slices1)

# torch.equal checks shapes and every element and returns a plain bool.
print(torch.equal(attn_mask, attn_mask1))

# Do not rest the conclusion on a single case. Check other window sizes,
# non-square maps, and every valid shift (H, W multiples of window_size).
configs = ((7, 56, 56), (7, 28, 14), (7, 7, 7), (4, 16, 8), (12, 48, 24))
print(all(
    masks_agree(h, w, ws, s)
    for ws, h, w in configs
    for s in range(1, ws)
))

结果为 True，即两种划分得到的 attn_mask 完全相同。这是否说明直接划分 4 个区域就够了？（前提是 H、W 为 window_size 的整数倍，而 `window_partition` 本身就要求这一点。）

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions