Attention is now widely used in deep learning. Given a query and a collection of key-value pairs, an attention module outputs a weighted sum of all the values, where the weights are computed from the similarities between the query and the keys, usually measured by their inner products. However, when the number of keys is large, applying such a module is expensive.
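As a minimal illustration (the softmax normalization of the weights is our assumption; the text above only specifies inner-product similarities), global dot-product attention can be sketched as:

```python
import torch

def dot_product_attention(q, k, v):
    # q: (M, d) queries; k, v: (N, d) keys and values
    # each weight is the (softmax-normalized) inner product between
    # a query and the corresponding key
    w = torch.softmax(q @ k.t(), dim=-1)   # (M, N) weight matrix
    return w @ v                           # (M, d) weighted sum of all values
```

The M×N weight matrix is exactly what becomes expensive when the number of keys N is large.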
To address this problem, researchers turn to local attention, in which only a small subset of the keys is involved for a given query. For images, "local" means an image region around each pixel. Local attention has achieved great success on image restoration tasks. However, current implementations are based on the im2col operation, which is memory-expensive, especially when the local patch is large.
Here, the queries Q, keys K, and values V are represented as C×H×W (channel, height, width) tensors generated by convolutions, and a "local region" is a C×k×k sub-tensor, where k is the patch size. Current implementations are based on the following steps:
- rearrange K and V into (k×k)×C×H×W tensors via im2col;
- compute the similarity matrix W between Q and K, a (k×k)×H×W tensor;
- compute the output O as the sum of V weighted by W, a C×H×W tensor.
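For reference, here is a minimal sketch of this plain im2col-based baseline in PyTorch (the function name, softmax normalization, zero padding, and odd patch sizes are our assumptions):

```python
import torch
import torch.nn.functional as F

def local_attention_im2col(q, k, v, kH=7, kW=7):
    """Plain im2col-based local attention baseline (sketch, assumes odd kH, kW)."""
    N, C, H, W = q.shape
    # step 1: rearrange K and V into (k*k) x C x H x W form via im2col (F.unfold)
    k_unf = F.unfold(k, (kH, kW), padding=(kH // 2, kW // 2)).view(N, C, kH * kW, H, W)
    v_unf = F.unfold(v, (kH, kW), padding=(kH // 2, kW // 2)).view(N, C, kH * kW, H, W)
    # step 2: similarity matrix W, (k*k) x H x W per sample: inner product of
    # each query with every key inside its local patch
    w = torch.softmax((q.unsqueeze(2) * k_unf).sum(dim=1), dim=1)
    # step 3: output O, C x H x W per sample: values weighted by W, summed over the patch
    return (w.unsqueeze(1) * v_unf).sum(dim=2)
```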
Clearly, the first step requires k×k times the memory to store the rearranged K and V. This can be avoided: our implementation computes W and O without rearranging the keys and values. To this end, we write two CUDA kernels and build a PyTorch extension on top of them.
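The kernels themselves are written in CUDA; the following pure-PyTorch reference (not the extension code, with the same softmax and zero-padding assumptions as above) illustrates what the two kernels compute. It accumulates over the k×k patch offsets and only ever stores the (k×k)×H×W weight tensor, never the rearranged keys and values:

```python
import torch
import torch.nn.functional as F

def local_attention_no_im2col(q, k, v, kH=7, kW=7):
    """Reference for what the two CUDA kernels compute, offset by offset,
    without materializing the (k*k) x C x H x W rearranged K and V."""
    N, C, H, W = q.shape
    k_pad = F.pad(k, (kW // 2, kW // 2, kH // 2, kH // 2))
    v_pad = F.pad(v, (kW // 2, kW // 2, kH // 2, kH // 2))
    # kernel 1: similarity between each query and the key at every patch offset
    w = q.new_empty(N, kH * kW, H, W)
    for idx in range(kH * kW):
        dy, dx = divmod(idx, kW)
        w[:, idx] = (q * k_pad[:, :, dy:dy + H, dx:dx + W]).sum(dim=1)
    w = torch.softmax(w, dim=1)
    # kernel 2: sum the values at every patch offset, weighted by W
    o = torch.zeros_like(q)
    for idx in range(kH * kW):
        dy, dx = divmod(idx, kW)
        o = o + w[:, idx:idx + 1] * v_pad[:, :, dy:dy + H, dx:dx + W]
    return o
```

Each slice k_pad[:, :, dy:dy + H, dx:dx + W] is a view, so no extra copy of K or V is materialized.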
Build and install the extension with:

```bash
python setup.py install
```

Requirements:

- PyTorch >= 1.4.0
- CUDA >= 10.0
We provide the Python wrapper in function.py. Here is an example:
```python
import torch
from function import LocalAttention

# kH and kW specify the local patch size
# works only on GPU
module = LocalAttention(inp_channels=3, out_channels=16, kH=7, kW=7).cuda()
x = torch.rand(32, 3, 64, 64).cuda()
# Q, K, V are generated by convolutions of x
y = module(x)
```
We evaluate the relative GPU memory and running time of our implementation compared with the plain PyTorch implementation: the first table is for the forward pass and the second for a forward-backward loop. Here we set H = W = 128 and C = 64.
Forward pass:

| k | Relative GPU Memory | Relative running time |
|---|---|---|
| 5 | 10.2% | 31.4% |
| 11 | 3.2% | 15.6% |
| 21 | 2.0% | 26.5% |

Forward-backward loop:

| k | Relative GPU Memory | Relative running time |
|---|---|---|
| 5 | 9.0% | 31.2% |
| 11 | 3.4% | 21.5% |
| 21 | 2.3% | 47.3% |
Our implementation reduces GPU memory usage by an order of magnitude and is faster than the plain PyTorch implementation.
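As a rough illustration of how such measurements could be taken (a hypothetical sketch, not the actual benchmark script; the helper name, batch size, and patch size are our assumptions):

```python
import time
import torch
from function import LocalAttention

def peak_memory_and_time(module, x, backward=False):
    """Measure peak GPU memory (bytes) and wall-clock time (s) of one pass."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    y = module(x)
    if backward:
        y.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated(), time.time() - start

# H = W = 128 and C = 64, matching the tables above
x = torch.rand(1, 64, 128, 128).cuda()
ours = LocalAttention(inp_channels=64, out_channels=64, kH=11, kW=11).cuda()
print(peak_memory_and_time(ours, x))                  # forward pass
print(peak_memory_and_time(ours, x, backward=True))   # forward-backward loop
```

The relative numbers above would then be the ratios of these measurements to those of the plain im2col-based implementation.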
Refer to /test for more results.