Move nn.glob.attention.GlobalAttention to nn.aggr.attention.AttentionalAggregation #4986
Conversation
Codecov Report
@@           Coverage Diff           @@
##           master    #4986   +/-   ##
=======================================
  Coverage   82.79%   82.79%
=======================================
  Files         330      330
  Lines       17978    17978
=======================================
  Hits        14885    14885
  Misses       3093     3093
=======================================
Continue to review full report at Codecov.
test/nn/aggr/test_attention.py (Outdated)

index = index.view(-1, 1).repeat(1, dim_size).view(-1)

assert aggr(x, index).size() == (dim_size, channels)
assert aggr(x, index,
Test

out = aggr(x, index)
assert out.size() == (dim_size, channels)
assert torch.allclose(aggr(x, index, dim_size=dim_size), out)

instead.
I added this, but just to confirm: you want to make sure that the output does not change if dim_size is provided? Maybe it would be better to test

out = aggr(x, index)
assert out.size() == (dim_size, channels)
assert torch.allclose(aggr(x, index, dim_size=dim_size + 1)[:3], out)

WDYT?
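For reference, here is a self-contained sketch of how the check discussed above could be put together in a test. It is only an illustration of the idea from this thread: the concrete sizes, the index construction, and the gate_nn / nn arguments passed to AttentionalAggregation are assumptions, not code taken from the PR.

import torch
from torch.nn import Linear

from torch_geometric.nn.aggr import AttentionalAggregation


def test_attentional_aggregation_dim_size():
    channels, dim_size = 16, 3
    x = torch.randn(dim_size * 4, channels)
    # Four elements per group: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].
    index = torch.arange(dim_size).view(-1, 1).repeat(1, 4).view(-1)

    # gate_nn maps each element to a scalar attention score; the optional
    # second module transforms the features before the weighted sum.
    aggr = AttentionalAggregation(Linear(channels, 1), Linear(channels, channels))

    out = aggr(x, index)
    assert out.size() == (dim_size, channels)

    # Passing an explicit dim_size equal to the inferred one must not
    # change the result.
    assert torch.allclose(aggr(x, index, dim_size=dim_size), out)

    # With a larger dim_size, the trailing rows correspond to empty groups,
    # and the first dim_size rows should still match the default output.
    out_padded = aggr(x, index, dim_size=dim_size + 1)
    assert out_padded.size() == (dim_size + 1, channels)
    assert torch.allclose(out_padded[:dim_size], out)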
Thanks for the comments @rusty1s!
Title changed: "nn.glob.attention.GlobalAttention to nn.aggr.attention.AttentionAggregation" → "nn.glob.attention.GlobalAttention to nn.aggr.attention.AttentionalAggregation"
Addresses #4712
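As a closing note on what the move implies for downstream code, a minimal usage sketch follows. It assumes the new import path torch_geometric.nn.aggr.AttentionalAggregation introduced by this PR and that the module keeps the gate_nn (and optional nn) constructor of the old GlobalAttention; whether the old nn.glob.GlobalAttention import remains available as a deprecated alias is not stated in this thread.

import torch
from torch.nn import Linear

from torch_geometric.nn.aggr import AttentionalAggregation

channels = 16
x = torch.randn(8, channels)
index = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])  # three groups/graphs

# New location and name after this PR.
aggr = AttentionalAggregation(gate_nn=Linear(channels, 1))
out = aggr(x, index)  # shape: [3, channels]

# The same module previously lived at nn.glob.GlobalAttention, e.g.
#   from torch_geometric.nn.glob import GlobalAttention
#   pool = GlobalAttention(gate_nn=Linear(channels, 1))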