
Commit

add flash attention
lucidrains committed Jun 19, 2023
1 parent 57ed18b commit f2d214f
Showing 4 changed files with 164 additions and 34 deletions.
16 changes: 14 additions & 2 deletions README.md
@@ -108,9 +108,12 @@ unet = SpaceTimeUnet(
dim = 64,
channels = 3,
dim_mult = (1, 2, 4, 8),
resnet_block_depths = (1, 1, 1, 2),
temporal_compression = (False, False, False, True),
self_attns = (False, False, False, True),
condition_on_timestep = False
condition_on_timestep = False,
attn_pos_bias = False,
flash_attn = True
).cuda()

# train on images
@@ -136,8 +139,8 @@ video_as_images_out = unet(video, enable_time = False)

- [x] give attention the best positional embeddings research has to offer
- [x] soup up the attention
- [x] add flash attention

- [ ] add flash attention
- [ ] make sure dalle2-pytorch can accept `SpaceTimeUnet` for training

## Citations
@@ -187,3 +190,12 @@ video_as_images_out = unet(video, enable_time = False)
url = {https://openreview.net/forum?id=GMYWzWztDx5},
}
```

```bibtex
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```
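
Putting the README hunks above together, here is a minimal sketch of how the two new flags are meant to be used (not part of the diff; the tensor shapes and the `.cuda()` call are illustrative, and `flash_attn = True` assumes PyTorch 2.0+ with the learned positional bias turned off):

```python
import torch
from make_a_video_pytorch import SpaceTimeUnet

# flash attention is only usable with attn_pos_bias = False,
# since the fused kernel cannot accept a learned relative position bias
unet = SpaceTimeUnet(
    dim = 64,
    channels = 3,
    dim_mult = (1, 2, 4, 8),
    resnet_block_depths = (1, 1, 1, 2),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, False, True),
    condition_on_timestep = False,
    attn_pos_bias = False,
    flash_attn = True
).cuda()

# (batch, channels, frames, height, width) - shapes are illustrative
video = torch.randn(1, 3, 16, 64, 64).cuda()
pred = unet(video)  # output has the same shape as the input
```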
114 changes: 114 additions & 0 deletions make_a_video_pytorch/attend.py
@@ -0,0 +1,114 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# constants

AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
return val is not None

def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner

print_once = once(print)

# main class

class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

# determine efficient attention configs for cuda and cpu

self.cpu_config = AttentionConfig(True, True, True)
self.cuda_config = None

if not torch.cuda.is_available() or not flash:
return

device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = AttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = AttentionConfig(False, True, True)

def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

q, k, v = map(lambda t: t.contiguous(), (q, k, v))

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.
)

return out

def forward(self, q, k, v, bias = None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""

q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

if self.flash:
assert not exists(bias)
return self.flash_attn(q, k, v)

scale = q.shape[-1] ** -0.5

# similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # add relative position bias, if supplied (the flash path asserts it is absent)

        if exists(bias):
            sim = sim + bias

        # attention

attn = sim.softmax(dim = -1)
attn = self.attn_dropout(attn)

# aggregate values

out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

return out
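
For reference, a small standalone exercise of the new `Attend` module (not part of the commit; shapes are assumed, and `flash = True` would additionally require PyTorch 2.0+):

```python
import torch
from make_a_video_pytorch.attend import Attend

attend = Attend(dropout = 0., flash = False)

# (batch, heads, sequence length, head dimension)
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

out = attend(q, k, v)                 # (2, 8, 1024, 64)

# a relative position bias can only be supplied on the non-flash path
bias = torch.randn(8, 1024, 1024)
out = attend(q, k, v, bias = bias)    # same shape as above
```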
66 changes: 35 additions & 31 deletions make_a_video_pytorch/make_a_video.py
@@ -9,6 +9,8 @@
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from make_a_video_pytorch.attend import Attend

# helper functions

def exists(val):
@@ -45,16 +47,17 @@ def forward(self, x):

# layernorm 3d

class ChanLayerNorm(nn.Module):
def __init__(self, dim):
class RMSNorm(nn.Module):
def __init__(self, chan, dim = 1):
super().__init__()
self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))
self.dim = dim
self.gamma = nn.Parameter(torch.ones(chan))

def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * var.clamp(min = eps).rsqrt() * self.g
dim = self.dim
right_ones = (dim + 1) if dim < 0 else (x.ndim - 1 - dim)
gamma = self.gamma.reshape(-1, *((1,) * right_ones))
return F.normalize(x, dim = dim) * (x.shape[dim] ** 0.5) * gamma

# feedforward

@@ -79,7 +82,7 @@ def __init__(self, dim, mult = 4):
)

self.proj_out = nn.Sequential(
ChanLayerNorm(inner_dim),
RMSNorm(inner_dim),
nn.Conv3d(inner_dim, dim, 1, bias = False)
)

@@ -180,14 +183,17 @@ def __init__(
self,
dim,
dim_head = 64,
heads = 8
heads = 8,
flash = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads

self.norm = nn.LayerNorm(dim)
self.attend = Attend(flash = flash)

self.norm = RMSNorm(dim, dim = -1)

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -200,23 +206,13 @@ def forward(
x,
rel_pos_bias = None
):

x = self.norm(x)

q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

q = q * self.scale

sim = einsum('b h i d, b h j d -> b h i j', q, k)

if exists(rel_pos_bias):
sim = sim + rel_pos_bias

attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = self.attend(q, k, v, bias = rel_pos_bias)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -283,14 +279,18 @@ def __init__(
dim_head = 64,
heads = 8,
add_feed_forward = True,
ff_mult = 4
ff_mult = 4,
pos_bias = True,
flash = False
):
super().__init__()
self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2)
assert not (flash and pos_bias), 'learned positional attention bias is not compatible with flash attention'

self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2) if pos_bias else None

self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)
self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1) if pos_bias else None

self.has_feed_forward = add_feed_forward
if not add_feed_forward:
@@ -312,7 +312,7 @@ def forward(
else:
x = rearrange(x, 'b c h w -> b (h w) c')

space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
space_rel_pos_bias = self.spatial_rel_pos_bias(h, w) if exists(self.spatial_rel_pos_bias) else None

x = self.spatial_attn(x, rel_pos_bias = space_rel_pos_bias) + x

@@ -325,7 +325,7 @@

x = rearrange(x, 'b c f h w -> (b h w) f c')

time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1]) if exists(self.temporal_rel_pos_bias) else None

x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x

@@ -543,7 +543,9 @@ def __init__(
resnet_block_depths = (2, 2, 2, 2),
attn_dim_head = 64,
attn_heads = 8,
condition_on_timestep = True
condition_on_timestep = True,
attn_pos_bias = True,
flash_attn = False
):
super().__init__()
assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
@@ -576,13 +578,15 @@ def __init__(

attn_kwargs = dict(
dim_head = attn_dim_head,
heads = attn_heads
heads = attn_heads,
pos_bias = attn_pos_bias,
            flash = flash_attn
)

mid_dim = dims[-1]

self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)
self.mid_attn = SpatioTemporalAttention(dim = mid_dim)
self.mid_attn = SpatioTemporalAttention(dim = mid_dim, **attn_kwargs)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)

for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(range(num_layers), self_attns, dim_in_out, temporal_compression, resnet_block_depths):
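
Similarly, a hedged sketch of the updated `SpatioTemporalAttention` arguments (illustrative; assumes the class is importable from the package top level as in the repository README, and that PyTorch 2.0+ is available for the flash path):

```python
import torch
from make_a_video_pytorch import SpatioTemporalAttention

# flash = True requires the learned continuous position bias to be disabled
attn = SpatioTemporalAttention(dim = 512, pos_bias = False, flash = True)

video_feats = torch.randn(1, 512, 8, 16, 16)   # (batch, dim, frames, height, width)
out = attn(video_feats)                        # same shape as the input
```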
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
version = '0.2.1',
version = '0.3.0',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
