Description
In the attention code I found an operation called `to_out`, and I cannot work out what it is supposed to do.
The specific code is:
```python
import torch
from torch import nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # A single projection shared by query, key and value.
        self.qkv = nn.Linear(dim, inner_dim, bias=False)

        # Map the concatenated heads (inner_dim) back to the model dimension (dim).
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # (b, n, inner_dim) -> (b, heads, n, dim_head)
        w = rearrange(self.qkv(x), 'b n (h d) -> b h n d', h=self.heads)
        dots = torch.matmul(w, w.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, w)
        # Concatenate heads: (b, heads, n, dim_head) -> (b, n, inner_dim)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```
After `forward` computes the result, the final output goes through the `to_out()` operation, but I could not find a corresponding theoretical justification in the MSSA part of the paper. Could you please explain it?
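For context, here is a minimal shape check of what `to_out` does in the code above (the `dim`/`heads`/`dim_head` values are arbitrary example sizes, not the repo's defaults): the heads are concatenated to width `inner_dim = heads * dim_head`, and `to_out` is simply the `Linear(inner_dim, dim)` plus `Dropout` that maps this back to the model width `dim`.

```python
# Minimal shape check for to_out (arbitrary example sizes, not the repo's defaults).
attn = Attention(dim=128, heads=8, dim_head=64, dropout=0.)
x = torch.randn(2, 16, 128)   # (batch, tokens, dim)

out = attn(x)
print(out.shape)              # torch.Size([2, 16, 128])

# Inside forward, the concatenated heads have width inner_dim = 8 * 64 = 512;
# to_out = Linear(512, 128) + Dropout projects them back to dim = 128.
```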
Also, where in the code are the LayerNorm before the MSSA module and the LayerNorm before ISTA implemented? I could not find the corresponding code.
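For reference, and only as an assumption based on the common ViT-style layout (I am not sure this repo is organized the same way): the pre-attention / pre-MLP LayerNorm is often not inside the `Attention` or ISTA modules themselves, but applied by a small wrapper in the Transformer block, e.g.:

```python
# A common ViT-style pre-norm wrapper (illustrative sketch; the PreNorm name and
# the exact wiring are assumptions on my part, not necessarily this repo's code).
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # LayerNorm is applied to the input before the wrapped sub-layer runs,
        # which is where a "LayerNorm before MSSA / before ISTA" would live.
        return self.fn(self.norm(x), **kwargs)

# e.g. inside the Transformer block:
#   PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
```

If the repo follows this pattern, the LayerNorms would appear in the Transformer block definition rather than inside the `Attention` / ISTA classes.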