diff --git a/README.md b/README.md
index d2403c7..dabf3f2 100644
--- a/README.md
+++ b/README.md
@@ -108,9 +108,12 @@ unet = SpaceTimeUnet(
     dim = 64,
     channels = 3,
     dim_mult = (1, 2, 4, 8),
+    resnet_block_depths = (1, 1, 1, 2),
     temporal_compression = (False, False, False, True),
     self_attns = (False, False, False, True),
-    condition_on_timestep = False
+    condition_on_timestep = False,
+    attn_pos_bias = False,
+    flash_attn = True
 ).cuda()
 
 # train on images
@@ -136,8 +139,8 @@ video_as_images_out = unet(video, enable_time = False)
 - [x] give attention the best positional embeddings research has to offer
 - [x] soup up the attention
+- [x] add flash attention
 
-- [ ] add flash attention
 - [ ] make sure dalle2-pytorch can accept `SpaceTimeUnet` for training
 
 ## Citations
 
@@ -187,3 +190,12 @@ video_as_images_out = unet(video, enable_time = False)
     url = {https://openreview.net/forum?id=GMYWzWztDx5},
 }
 ```
+
+```bibtex
+@inproceedings{dao2022flashattention,
+    title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+    author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+    booktitle = {Advances in Neural Information Processing Systems},
+    year = {2022}
+}
+```
diff --git a/make_a_video_pytorch/attend.py b/make_a_video_pytorch/attend.py
new file mode 100644
index 0000000..ef23a23
--- /dev/null
+++ b/make_a_video_pytorch/attend.py
@@ -0,0 +1,119 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange
+
+# constants
+
+AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def once(fn):
+    called = False
+    @wraps(fn)
+    def inner(x):
+        nonlocal called
+        if called:
+            return
+        called = True
+        return fn(x)
+    return inner
+
+print_once = once(print)
+
+# main class
+
+class Attend(nn.Module):
+    def __init__(
+        self,
+        dropout = 0.,
+        flash = False
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.attn_dropout = nn.Dropout(dropout)
+
+        self.flash = flash
+        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # determine efficient attention configs for cuda and cpu
+
+        self.cpu_config = AttentionConfig(True, True, True)
+        self.cuda_config = None
+
+        if not torch.cuda.is_available() or not flash:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+            self.cuda_config = AttentionConfig(True, False, False)
+        else:
+            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+            self.cuda_config = AttentionConfig(False, True, True)
+
+    def flash_attn(self, q, k, v):
+        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
+
+        q, k, v = map(lambda t: t.contiguous(), (q, k, v))
+
+        # check if there is a compatible device for flash attention
+
+        config = self.cuda_config if is_cuda else self.cpu_config
+
+        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
+
+        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p = self.dropout if self.training else 0.
+            )
+
+        return out
+
+    def forward(self, q, k, v, bias = None):
+        """
+        einstein notation
+        b - batch
+        h - heads
+        n, i, j - sequence length (base sequence length, source, target)
+        d - feature dimension
+        """
+
+        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+        if self.flash:
+            assert not exists(bias)
+            return self.flash_attn(q, k, v)
+
+        scale = q.shape[-1] ** -0.5
+
+        # similarity
+
+        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+        # relative positional bias (only supported on this non-flash path)
+
+        if exists(bias):
+            sim = sim + bias
+
+        # attention
+
+        attn = sim.softmax(dim = -1)
+        attn = self.attn_dropout(attn)
+
+        # aggregate values
+
+        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+        return out
diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py
index b7bdb57..cf4ba49 100644
--- a/make_a_video_pytorch/make_a_video.py
+++ b/make_a_video_pytorch/make_a_video.py
@@ -9,6 +9,8 @@
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
+from make_a_video_pytorch.attend import Attend
+
 # helper functions
 
 def exists(val):
@@ -45,16 +47,17 @@ def forward(self, x):
 
 # layernorm 3d
 
-class ChanLayerNorm(nn.Module):
-    def __init__(self, dim):
+class RMSNorm(nn.Module):
+    def __init__(self, chan, dim = 1):
         super().__init__()
-        self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))
+        self.dim = dim
+        self.gamma = nn.Parameter(torch.ones(chan))
 
     def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * var.clamp(min = eps).rsqrt() * self.g
+        dim = self.dim
+        right_ones = (dim + 1) if dim < 0 else (x.ndim - 1 - dim)
+        gamma = self.gamma.reshape(-1, *((1,) * right_ones))
+        return F.normalize(x, dim = dim) * (x.shape[dim] ** 0.5) * gamma
 
 # feedforward
 
@@ -79,7 +82,7 @@ def __init__(self, dim, mult = 4):
         )
 
         self.proj_out = nn.Sequential(
-            ChanLayerNorm(inner_dim),
+            RMSNorm(inner_dim),
             nn.Conv3d(inner_dim, dim, 1, bias = False)
         )
 
@@ -180,14 +183,17 @@ def __init__(
         self,
         dim,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        flash = False
     ):
         super().__init__()
         self.heads = heads
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads
 
-        self.norm = nn.LayerNorm(dim)
+        self.attend = Attend(flash = flash)
+
+        self.norm = RMSNorm(dim, dim = -1)
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -200,23 +206,13 @@ def forward(
         self,
         x,
         rel_pos_bias = None
     ):
-        x = self.norm(x)
 
         q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
-        q = q * self.scale
-
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
-
-        if exists(rel_pos_bias):
-            sim = sim + rel_pos_bias
-
-        attn = sim.softmax(dim = -1)
-
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = self.attend(q, k, v, bias = rel_pos_bias)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -283,14 +279,18 @@ def __init__(
         dim_head = 64,
         heads = 8,
         add_feed_forward = True,
-        ff_mult = 4
+        ff_mult = 4,
+        pos_bias = True,
+        flash = False
     ):
         super().__init__()
-        self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
-        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2)
+        assert not (flash and pos_bias), 'learned positional attention bias is not compatible with flash attention'
+
+        self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
+        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2) if pos_bias else None
 
-        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
-        self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)
+        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
+        self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1) if pos_bias else None
 
         self.has_feed_forward = add_feed_forward
         if not add_feed_forward:
@@ -312,7 +312,7 @@ def forward(
         else:
             x = rearrange(x, 'b c h w -> b (h w) c')
 
-        space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
+        space_rel_pos_bias = self.spatial_rel_pos_bias(h, w) if exists(self.spatial_rel_pos_bias) else None
 
         x = self.spatial_attn(x, rel_pos_bias = space_rel_pos_bias) + x
 
@@ -325,7 +325,7 @@ def forward(
 
         x = rearrange(x, 'b c f h w -> (b h w) f c')
 
-        time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
+        time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1]) if exists(self.temporal_rel_pos_bias) else None
 
         x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x
 
@@ -543,7 +543,9 @@ def __init__(
         resnet_block_depths = (2, 2, 2, 2),
         attn_dim_head = 64,
         attn_heads = 8,
-        condition_on_timestep = True
+        condition_on_timestep = True,
+        attn_pos_bias = True,
+        flash_attn = False
     ):
         super().__init__()
         assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
@@ -576,13 +578,15 @@ def __init__(
 
         attn_kwargs = dict(
             dim_head = attn_dim_head,
-            heads = attn_heads
+            heads = attn_heads,
+            pos_bias = attn_pos_bias,
+            flash = flash_attn
         )
 
         mid_dim = dims[-1]
 
         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)
-        self.mid_attn = SpatioTemporalAttention(dim = mid_dim)
+        self.mid_attn = SpatioTemporalAttention(dim = mid_dim, **attn_kwargs)
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)
 
         for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(range(num_layers), self_attns, dim_in_out, temporal_compression, resnet_block_depths):
diff --git a/setup.py b/setup.py
index 412dbb2..ae9c875 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.1',
+  version = '0.3.0',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
   author = 'Phil Wang',
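For reference, a minimal usage sketch of the new options (not part of the patch itself). It mirrors the README snippet in this change; the tensor shapes are assumed toy values. Note that `attn_pos_bias = False` is required whenever `flash_attn = True`, per the assert added in `SpatioTemporalAttention`, and the flash path needs PyTorch 2.0 or above.

```python
import torch
from make_a_video_pytorch import SpaceTimeUnet

# flash attention requires disabling the learned continuous positional bias
unet = SpaceTimeUnet(
    dim = 64,
    channels = 3,
    dim_mult = (1, 2, 4, 8),
    resnet_block_depths = (1, 1, 1, 2),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, False, True),
    condition_on_timestep = False,
    attn_pos_bias = False,
    flash_attn = True
).cuda()

# assumed toy shape: (batch, channels, frames, height, width)
video = torch.randn(1, 3, 16, 128, 128).cuda()

video_out = unet(video)                        # full space-time pass
frames_out = unet(video, enable_time = False)  # same weights, frames treated as independent images

assert video_out.shape == video.shape
```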