
Question about the attention implementation #2

Closed
ddz16 opened this issue Nov 9, 2021 · 5 comments

ddz16 commented Nov 9, 2021

Hi, is the hierarchical attention you mention essentially band attention (as shown in the figure below), just with the window size growing exponentially as the layer index increases? If so, shouldn't the body of the for loop in this function in model.py be changed to window_mask[:, i, i:i+self.bl] = 1?

def construct_window_mask(self):
    window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
    for i in range(self.bl):
        window_mask[:, :, i:i+self.bl] = 1
    return window_mask.to(device)

[figure: band attention pattern]
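For reference, here is a minimal standalone sketch of what that one-character change would look like (a hypothetical free function; bl plays the role of self.bl in model.py):

import torch

def construct_window_mask_banded(bl, device="cpu"):
    # Same shape as the original mask: one extra half-window on each side.
    window_mask = torch.zeros((1, bl, bl + 2 * (bl // 2)))
    for i in range(bl):
        # Row i attends only to the bl positions starting at offset i,
        # i.e. a band-attention pattern, instead of filling every row.
        window_mask[:, i, i:i + bl] = 1
    return window_mask.to(device)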

ChinaYi (Owner) commented Nov 10, 2021

The implementation is different. Suppose the sequence length is T. With your form of mask, you first have to compute the full T x T attention and then multiply it by the mask, and that approach blows up GPU memory. The implementation I use instead splits the computation in figure (b) into blocks, so at any time only a 2*window_size x 2*window_size block of attention values has to be kept, which saves a great deal of memory. For the details, take a careful look at the implementation of the _sliding_window_self_att function. Best~
[figure: block-wise sliding-window computation]
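To make the memory argument concrete, here is a rough sketch of the per-block idea (this is not the repo's _sliding_window_self_att, just an illustration; it assumes T is a multiple of the block size bl and omits the window mask that the real function presumably applies inside each block):

import torch
import torch.nn.functional as F

def blockwise_attention_sketch(q, k, v, bl):
    # q, k, v: (batch, T, dim)
    b, T, d = q.shape
    pad = bl // 2
    # Pad keys/values so each block can also look half a window to each side.
    k_pad = F.pad(k, (0, 0, pad, pad))        # (b, T + 2*pad, d)
    v_pad = F.pad(v, (0, 0, pad, pad))
    out = torch.zeros_like(q)
    for s in range(0, T, bl):                 # one block of bl queries at a time
        q_blk = q[:, s:s + bl]                # (b, bl, d)
        k_blk = k_pad[:, s:s + bl + 2 * pad]  # (b, bl + 2*pad, d)
        v_blk = v_pad[:, s:s + bl + 2 * pad]
        # The score matrix is only bl x (bl + 2*pad); a full T x T matrix is never built.
        scores = q_blk @ k_blk.transpose(1, 2) / d ** 0.5
        out[:, s:s + bl] = torch.softmax(scores, dim=-1) @ v_blk
    return out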

ddz16 (Author) commented Nov 10, 2021

Hi, I fully agree with and understand the point that computing the full attention and then masking blows up GPU memory, and that you switched to a sliding-window implementation instead.

However, the mask produced by your construct_window_mask function looks like this:
[screenshot: mask produced by the original construct_window_mask]

After a tiny change to the function (just one character), window_mask[:, :, i:i+self.bl] = 1 -> window_mask[:, i, i:i+self.bl] = 1, the result is:
[screenshot: mask produced after the one-character change]

I think the modified function better matches what you intended.

I'm not sure whether my understanding is correct.
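A quick way to reproduce the two results described above (a standalone snippet, using bl = 4 purely for illustration):

import torch

bl = 4
original = torch.zeros((1, bl, bl + 2 * (bl // 2)))
modified = torch.zeros((1, bl, bl + 2 * (bl // 2)))
for i in range(bl):
    original[:, :, i:i + bl] = 1   # current code: every row gets the same columns
    modified[:, i, i:i + bl] = 1   # one-character change: a shifted band per row
print(original[0])  # almost all ones (for even bl, only the last column stays zero)
print(modified[0])  # banded matrix, one shifted window per row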

ChinaYi (Owner) commented Nov 10, 2021

Maybe you could try changing it to window_mask[:, i, i:i+self.bl] = 1 and see whether the performance changes. I haven't looked at it carefully yet, but I think your understanding is probably right. The way the current code is written is a bit like a disguised sliding window, as shown in the figure below, assuming a window size of 2:
[figure: sliding-window view of the current mask with window size 2]

I ran experiments on this before: whether you slide the window frame by frame or simply split the video into segments and apply attention separately inside each segment, as long as the hierarchical structure is kept (window size doubling or segment length doubling), the performance is high. I'll check this out carefully when I have some free time. Thanks for spotting this~
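As a small illustration of the segment-based variant with doubling segment length (an assumption-laden sketch, not the repo's block-wise code):

import torch

def segment_mask(T, seg_len):
    # Frames attend only within their own segment of length seg_len.
    mask = torch.zeros(T, T)
    for s in range(0, T, seg_len):
        mask[s:s + seg_len, s:s + seg_len] = 1
    return mask

T = 16
# Segment length doubles with depth (2, 4, 8, ...), giving the hierarchical structure.
layer_masks = [segment_mask(T, 2 ** (layer + 1)) for layer in range(3)]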

ddz16 (Author) commented Nov 10, 2021

Yes, I noticed the same thing: as long as this hierarchical structure is preserved, the performance is good! Thanks for your reply.

ddz16 closed this as completed Nov 10, 2021
ChinaYi (Owner) commented Nov 10, 2021

Hahaha. Exactly, that is the key finding of the paper: it doesn't matter which efficient version you use. That is also why the code provides a block-wise variant as well; the two perform about the same, and the visualizations look identical too.
