[feat] Add pooling operator from Poolformer (facebookresearch#297)
blefaudeux authored May 9, 2022
1 parent 90a1ec7 commit 7705e5e
Showing 11 changed files with 107 additions and 20 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -145,6 +145,8 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
   - *[FNet: Mixing Tokens with Fourier Transforms, Lee-Thorp et al.](https://arxiv.org/abs/2105.03824v1)*
 - [CompositionalAttention](xformers/components/attention/compositional.py)
   - *[Compositional Attention: Disentangling search and retrieval, S. Mittal et al.](https://arxiv.org/pdf/2110.09419v1.pdf)*
+- [2D Pooling](xformers/components/attention/pooling.py)
+  - *[Metaformer is actually what you need for vision, Yu et al.](https://arxiv.org/pdf/2111.11418v1.pdf)*
 
 - ... add a new one [see Contribution.md](CONTRIBUTING.md)

Binary file modified docs/plots/memory_vs_attention.png
Binary file modified docs/plots/runtime_vs_attention.png
9 changes: 4 additions & 5 deletions examples/cifarMetaformer.py
@@ -108,11 +108,7 @@ def forward(self, x):
         x = self.trunk(x)
         x = self.ln(x)
 
-        if self.hparams.classifier == Classifier.TOKEN:
-            x = x[:, 0]  # only consider the token, we're classifying anyway
-        elif self.hparams.classifier == Classifier.GAP:
-            x = x.mean(dim=1)  # mean over sequence len
-
+        x = x.mean(dim=1)  # mean over sequence len
         x = self.head(x)
         return x

@@ -129,6 +125,9 @@ def forward(self, x):
     NUM_WORKERS = 4
     GPUS = 1
 
+    torch.cuda.manual_seed_all(42)
+    torch.manual_seed(42)
+
     train_transforms = transforms.Compose(
         [
             transforms.RandomCrop(32, padding=4),
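For context, a minimal sketch of the global-average-pooling classification head the example now uses unconditionally; the shapes and the head module below are illustrative stand-ins, not values taken from the example:

import torch
import torch.nn as nn

# Illustrative shapes only; the actual example derives these from its config.
B, HW, C, NUM_CLASSES = 8, 64, 128, 10

tokens = torch.randn(B, HW, C)    # trunk output: (batch, seq_len, channels)
head = nn.Linear(C, NUM_CLASSES)  # stand-in for self.head

pooled = tokens.mean(dim=1)       # global average pooling over the sequence
logits = head(pooled)             # (batch, num_classes)
assert logits.shape == (B, NUM_CLASSES)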
11 changes: 3 additions & 8 deletions tests/test_attentions.py
@@ -97,7 +97,7 @@ def test_order_invariance(
     device: torch.device,
 ):
 
-    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "poolling":
+    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
         pytest.skip(f"{attention_name} requires squared sequence lengths")
 
     torch.manual_seed(42)
@@ -286,7 +286,7 @@ def test_broadcast_batch_dimension(
     device: torch.device,
     batch_sizes: Tuple[int, int, int],
 ):
-    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "poolling":
+    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
         pytest.skip(f"{attention_name} requires squared sequence lengths")
 
     Q_BATCH, K_BATCH, V_BATCH = batch_sizes
@@ -370,12 +370,7 @@ def test_torch_script_ability(
     heads: int,
     attn_dropout: float,
 ):
-    if attention_name in {
-        "favor",
-        "global",
-        "local",
-        "random",
-    }:
+    if attention_name in {"favor", "global", "local", "random", "pooling"}:
         # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
         pytest.skip(f"{attention_name} does not support scripting yet.")

4 changes: 2 additions & 2 deletions tests/test_hierarchical_transformer.py
@@ -23,15 +23,15 @@ def test_hierarchical_transformer():
     base_hierarchical_configs = [
         BasicLayerConfig(
             embedding=64,
-            attention_mechanism="scaled_dot_product",
+            attention_mechanism="pooling",
             patch_size=7,
             stride=4,
             padding=2,
             seq_len=image_size * image_size // 16,
         ),
         BasicLayerConfig(
             embedding=128,
-            attention_mechanism="scaled_dot_product",
+            attention_mechanism="pooling",
             patch_size=3,
             stride=2,
             padding=1,
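As a side note on these configs, a small sketch of why the strides above suit the new pooling attention: the operator reshapes (B, HW, C) tokens back into a square (B, C, H, H) grid, so every stage must produce a perfect-square sequence length. The image_size value is assumed here for illustration:

import math

image_size = 32   # assumed square input size, for illustration
strides = [4, 2]  # strides of the two BasicLayerConfig stages above

side = image_size
for stride in strides:
    side = side // stride  # spatial side length after this stage
    seq_len = side * side  # tokens fed to the pooling attention
    assert int(math.sqrt(seq_len)) ** 2 == seq_len  # square token grid
    print(f"stride {stride}: {side}x{side} grid -> seq_len={seq_len}")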
3 changes: 1 addition & 2 deletions xformers/benchmarks/LRA/run_grid_search.py
@@ -7,10 +7,9 @@
 import itertools
 import os
 import uuid
-from collections import Iterable
 from datetime import date
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Iterable
 
 import submitit

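A brief note on the import change above, with a minimal check for illustration: the ABC aliases were removed from the collections namespace in Python 3.10, so Iterable now has to come from typing (as done here) or from collections.abc:

# `from collections import Iterable` raises ImportError on Python 3.10+.
from collections.abc import Iterable  # typing.Iterable also works for annotations

assert isinstance([1, 2, 3], Iterable)
assert not isinstance(42, Iterable)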
2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_encoder.py
@@ -325,7 +325,7 @@ def plot(args, results: List[Dict[str, Any]]):
"-emb", "--embedding_dim", nargs="+", default=[64, 128, 256], type=int
)
parser.add_argument(
"-sl", "--sequence_length", nargs="+", default=[512, 1024], type=int
"-sl", "--sequence_length", nargs="+", default=[576, 1024], type=int
)
parser.add_argument("-bs", "--batch_size", nargs="+", default=[8, 16, 32], type=int)
parser.add_argument("-heads", "--heads", nargs="+", default=[8, 16], type=int)
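A plausible reading of the new default, sketched below: 512 is not a perfect square, while 576 = 24 x 24 (and 1024 = 32 x 32), so both default sequence lengths can be folded into the square token grid that the new pooling attention expects:

import math

for seq_len in (512, 576, 1024):
    side = int(math.sqrt(seq_len))
    if side * side == seq_len:
        print(f"seq_len={seq_len}: square grid ({side}x{side})")
    else:
        print(f"seq_len={seq_len}: not a perfect square")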
18 changes: 17 additions & 1 deletion xformers/components/attention/csrc/cuda/attention.cu
@@ -420,7 +420,23 @@ __global__ void attention_kernel(
   }
 
   // Computes s_prime, buffer (aka v_prime) and m_prime
-  UnrollLoop<true, scalar_t, vec_t, kBlockSizeK, kBlockSizeQ, BUFFER_SIZE, WARP_SIZE>::eval(query_block, key[batch_idx], value[batch_idx], m_prime, s_prime, buffer, K, N);
+  UnrollLoop<
+      true,
+      scalar_t,
+      vec_t,
+      kBlockSizeK,
+      kBlockSizeQ,
+      BUFFER_SIZE,
+      WARP_SIZE>::
+      eval(
+          query_block,
+          key[batch_idx],
+          value[batch_idx],
+          m_prime,
+          s_prime,
+          buffer,
+          K,
+          N);
 
   aggregate_coeffs<scalar_t, vec_t, kBlockSizeQ, WARP_SIZE, BUFFER_SIZE>(
       m_prime, s_prime, buffer, K);
1 change: 0 additions & 1 deletion xformers/components/attention/feature_maps/softmax.py
@@ -37,7 +37,6 @@ def __init__(
     ):
         super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
         self.softmax_temp = softmax_temp
-        self.offset = -1.0
 
         # Handle the scaling from all kernels by √m.
         # This normalizes for all the feature maps involved
77 changes: 77 additions & 0 deletions xformers/components/attention/pooling.py
@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class PoolingAttentionConfig(AttentionConfig):
    pool_size: int  # size of the local 2D pooling window
    stride: Optional[int]  # stride of the pooling operation
    padding: Optional[int]  # padding around the token grid, defaults to pool_size // 2


@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
def __init__(
self,
pool_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
*_,
**__,
):
"""
Pooling token mixing mechanism, as proposed in
`Metaformer is actually what you need for vision`_, Yu et al (2021).
The original notation is kept as is.
.. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
"""
super().__init__()

padding = padding if padding is not None else pool_size // 2
self.pool = nn.AvgPool2d(
pool_size,
stride=stride,
padding=pool_size // 2,
count_include_pad=False,
)

        # MHA related flags:
        # k and q do not need to have the same dimensions
        self.requires_same_k_q_dimensions = False

        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True

        # This operator does not really handle q, k, v: no input projection needed
        self.requires_input_projection = False

    def forward(self, q: torch.Tensor, *_, **__):
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool
        x_pool = self.pool(q) - q  # compensate for the residual path

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)
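
A short usage sketch of the operator defined above; the tensor sizes are illustrative, and direct instantiation is an assumption (inside xformers the block would normally be built from the registered "pooling" config):

import torch

from xformers.components.attention.pooling import Pooling

# 2 batches, a 16x16 token grid (HW = 256 is a perfect square), 96 channels
pooling = Pooling(pool_size=3, stride=1)

x = torch.randn(2, 256, 96)  # (B, HW, C)
y = pooling(x)               # pure token mixing, no q/k/v projections involved

assert y.shape == x.shape    # the (B, HW, C) layout is preserved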
