Skip to content

Commit 2ab9db4

Browse files
Update for mask operation
1 parent 8a7b783 commit 2ab9db4

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

Fastformer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import einops
2+
from einops import rearrange
23
import torch
34
import torch.nn as nn
45

@@ -21,9 +22,13 @@ def forward(self, x, mask = None):
2122
value = self.weight_v(x)
2223
b, n, d = query.shape
2324

25+
mask_value = torch.finfo(x.dtype).min
26+
mask = rearrange(mask, 'b n -> b () n')
27+
2428
# Caculate the global query
2529
alpha_weight = torch.softmax(torch.mul(query, self.weight_alpha) * self.scale_factor, dim = -1)
2630
global_query = query * alpha_weight
31+
global_query = global_query.masked_fill(~mask, mask_value)
2732
global_query = torch.einsum('b n d -> b d', global_query)
2833

2934
# Model the interaction between global query vector and the key vector
@@ -42,5 +47,6 @@ def forward(self, x, mask = None):
4247
if __name__ == '__main__':
4348
model = Fastformer(dim = 3, decode_dim = 8)
4449
x = torch.randn(4, 6, 3)
45-
result = model(x)
50+
mask = torch.ones(1, 8).bool()
51+
result = model(x, mask)
4652
print(result.size())

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import Fastformer
1212

1313
model = Fastformer(dim = 3, decode_dim = 8)
1414
x = torch.randn(4, 6, 3)
15-
result = model(x)
15+
mask = torch.ones(1, 8).bool()
16+
result = model(x, mask)
1617
print(result.size())
1718
```
1819

0 commit comments

Comments
 (0)