@@ -2,6 +2,7 @@
 from functools import partial
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F

 from einops import rearrange

@@ -49,7 +50,7 @@ def attention(

 # memory efficient attention

-def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout = 0., training = False):
     q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

     weight = einsum('b h i d, b h j d -> b h i j', q, k)
@@ -71,6 +72,8 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
     weight = weight - weight_max

     exp_weight = weight.exp()
+    if training:
+        exp_weight = F.dropout(exp_weight, p = dropout, training = training)
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
@@ -84,7 +87,9 @@ def memory_efficient_attention(
     attn_bias = None,
     q_bucket_size = 512,
     k_bucket_size = 1024,
-    eps = 1e-8
+    eps = 1e-8,
+    dropout = 0.,
+    training = False
 ):
     scale = q.shape[-1] ** -0.5
     q = q * scale
@@ -131,7 +136,9 @@ def memory_efficient_attention(
                 mask_chunk,
                 attn_bias_chunk,
                 causal,
-                (q_start_index, k_start_index)
+                (q_start_index, k_start_index),
+                dropout = dropout,
+                training = training
             )

             exp_weights.append(exp_weight_chunk)
@@ -175,7 +182,7 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.causal = causal
-
+        self.dropout = dropout
         inner_dim = heads * dim_head

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
@@ -212,7 +219,8 @@ def forward(

         attn_fn = attention if not memory_efficient else memory_efficient_attention

-        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
+        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size,
+                      k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)

         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
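
For context (not part of the commit), here is a minimal usage sketch of the patched function, assuming `memory_efficient_attention` is imported from the module edited above and that queries, keys and values are already split per head; shapes and the dropout rate below are illustrative.

import torch

# hypothetical shapes: batch 1, 8 heads, 2048 tokens, head dimension 64
q = torch.randn(1, 8, 2048, 64)
k = torch.randn(1, 8, 2048, 64)
v = torch.randn(1, 8, 2048, 64)

# training = True -> dropout is applied to the exponentiated attention weights
# inside each (q_bucket, k_bucket) chunk before the values are aggregated
out = memory_efficient_attention(
    q, k, v,
    causal = True,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    dropout = 0.1,
    training = True
)

# at inference, pass training = False (the default) so no attention weights are dropped
out = memory_efficient_attention(q, k, v, causal = True, dropout = 0.1, training = False)

When going through the Attention module instead, the new self.dropout attribute and nn.Module's self.training flag are forwarded automatically, as shown in the forward hunk above.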