forked from facebookresearch/xformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] Add pooling operator from Poolformer (facebookresearch#297)
- Loading branch information
1 parent
90a1ec7
commit 7705e5e
Showing
11 changed files
with
107 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import math | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from xformers.components.attention import Attention, AttentionConfig, register_attention | ||
|
||
|
||
@dataclass
class PoolingAttentionConfig(AttentionConfig):
    # NOTE(review): the original inline comments were copy-pasted from another
    # config and described sequence/latent dimensions; corrected below to match
    # how __init__ actually feeds these into nn.AvgPool2d.
    pool_size: int  # spatial extent of the 2D average-pooling window
    stride: Optional[int]  # stride of the pooling window
    padding: Optional[int]  # explicit padding; None means "derive from pool_size"
|
||
|
||
@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
    def __init__(
        self,
        pool_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        *_,
        **__,
    ):
        """
        Pooling token mixing mechanism, as proposed in
        `Metaformer is actually what you need for vision`_, Yu et al (2021).

        The original notation is kept as is.

        Args:
            pool_size: spatial extent of the 2D average-pooling window.
            stride: stride of the pooling window.
            padding: explicit padding around the feature map. Defaults to
                ``pool_size // 2``, which (with ``stride=1``) preserves the
                spatial resolution.

        .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
        """
        super().__init__()

        # Resolve the default, then actually use it.
        # BUG FIX: the resolved value was previously computed but ignored —
        # the pool layer hard-coded padding=pool_size // 2, silently dropping
        # any caller-supplied padding. Default behavior is unchanged.
        padding = padding if padding is not None else pool_size // 2
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=stride,
            padding=padding,
            count_include_pad=False,
        )

        # MHA related flags:
        # This operator does not really handle q, k, v — only q is consumed in
        # forward(), so k and q are required to share dimensions.
        # (A contradictory earlier assignment of False to this flag was dead
        # code and has been removed; True was always the effective value.)
        self.requires_same_k_q_dimensions = True

        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True

    def forward(self, q: torch.Tensor, *_, **__):
        """Apply 2D average-pool token mixing to ``q`` (shape B x HW x C).

        The token axis is assumed to be a flattened square feature map,
        i.e. HW must be a perfect square.
        """
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW, "Pooling expects a square number of tokens"

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool; subtracting the input compensates for the residual path
        # wrapped around the token mixer in the Metaformer block.
        x_pool = self.pool(q) - q

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)