Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClipQKV #197

Merged
merged 11 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 79 additions & 27 deletions examples/llm/src/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
except ImportError as e:
raise e

self.clip_qkv = cfg.get('attn_clip_qkv')
self.attn_qk_ln = cfg.get('attn_qk_ln')
self.d_model = cfg.d_model
self.n_heads = cfg.n_heads

if self.attn_qk_ln:
if self.attn_qk_ln or self.clip_qkv:
self.W_qkv = nn.Linear(self.d_model,
3 * self.d_model,
bias=True,
Expand All @@ -138,8 +139,9 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):

self.out_proj._is_residual = True # type: ignore

self.q_ln = nn.LayerNorm(self.d_model, device=device)
self.k_ln = nn.LayerNorm(self.d_model, device=device)
if self.attn_qk_ln:
self.q_ln = nn.LayerNorm(self.d_model, device=device)
self.k_ln = nn.LayerNorm(self.d_model, device=device)
else:
self.mhsa = FlashMHA(
embed_dim=cfg.d_model,
Expand All @@ -154,15 +156,18 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):

def forward(self, x, key_padding_mask, attn_mask=None):
assert attn_mask is None
if self.attn_qk_ln:
qkv = self.W_qkv(x)

# Applying layernorm to qk
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype)
k = self.k_ln(k).to(dtype)
qkv = torch.cat([q, k, v], dim=-1)
if self.attn_qk_ln or self.clip_qkv:
qkv = self.W_qkv(x)
if self.clip_qkv:
qkv = qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
if self.attn_qk_ln:
# Applying layernorm to qk
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype)
k = self.k_ln(k).to(dtype)
qkv = torch.cat([q, k, v], dim=-1)

# attention
qkv = rearrange(qkv,
Expand Down Expand Up @@ -213,22 +218,46 @@ class TritonFlashCausalAttention(nn.Module):
def __init__(self, cfg: DictConfig, device: Optional[str] = None):
super().__init__()
try:
from examples.llm.src.models.layers.flash_attention import \
FlashMHA # type: ignore
from examples.llm.src.models.layers.flash_attention import ( # type: ignore
FlashAttention, FlashMHA)
except ImportError as e:
raise e

assert cfg.attn_pdrop == 0, 'triton kernel does not support attn_dropout'

self.mhsa = FlashMHA(
embed_dim=cfg.d_model,
num_heads=cfg.n_heads,
bias=True,
batch_first=True,
causal=True,
device=device,
)
self.mhsa.out_proj._is_residual = True # type: ignore
self.clip_qkv = cfg.get('attn_clip_qkv')
self.attn_qk_ln = cfg.get('attn_qk_ln')
self.d_model = cfg.d_model
self.n_heads = cfg.n_heads

if self.attn_qk_ln or self.clip_qkv:
self.Wqkv = nn.Linear(self.d_model,
3 * self.d_model,
bias=True,
device=device)
self.inner_attn = FlashAttention(num_heads=cfg.n_heads,
vchiley marked this conversation as resolved.
Show resolved Hide resolved
softmax_scale=None,
vchiley marked this conversation as resolved.
Show resolved Hide resolved
device=device)
self.out_proj = nn.Linear(self.d_model,
self.d_model,
bias=True,
device=device)

self.out_proj._is_residual = True # type: ignore

if self.attn_qk_ln:
self.q_ln = nn.LayerNorm(self.d_model, device=device)
self.k_ln = nn.LayerNorm(self.d_model, device=device)
else:
self.mhsa = FlashMHA(
embed_dim=cfg.d_model,
num_heads=cfg.n_heads,
bias=True,
batch_first=True,
causal=True,
device=device,
)
self.mhsa.out_proj._is_residual = True # type: ignore

warnings.warn(
'While `attn_impl: triton` can be faster than `attn_impl: flash` '
Expand All @@ -237,11 +266,34 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
'using `attn_impl: flash`.')

def forward(self, x, key_padding_mask=None, attn_mask=None):
assert key_padding_mask is None
return self.mhsa(x,
key_padding_mask=None,
attn_mask=attn_mask,
need_weights=False)
if self.attn_qk_ln or self.clip_qkv:
qkv = self.Wqkv(x)
if self.clip_qkv:
qkv = qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
vchiley marked this conversation as resolved.
Show resolved Hide resolved
if self.attn_qk_ln:
# Applying layernorm to qk
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype)
k = self.k_ln(k).to(dtype)
qkv = torch.cat([q, k, v], dim=-1)

# attention
context, attn_weights = self.inner_attn(
qkv,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
is_causal=True,
need_weights=False,
average_attn_weights=False)

return self.out_proj(context), attn_weights

else:
return self.mhsa(x,
key_padding_mask=None,
attn_mask=attn_mask,
need_weights=False)

@staticmethod
def mask_shape(n_heads, seq_len, alibi):
Expand Down
9 changes: 7 additions & 2 deletions examples/llm/src/models/mosaic_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,14 @@ def __init__(self, cfg: DictConfig):
else:
raise ValueError(f'Unknown attn_impl={cfg.attn_impl}')

if cfg.get('attn_qk_ln') and cfg.attn_impl != 'flash':
if cfg.get('attn_qk_ln') and not cfg.attn_impl in ['flash', 'triton']:
vchiley marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(
'LayerNorm over queries and keys in attention is only implemented with flash attention.'
'LayerNorm over queries and keys in attention is only implemented with flash and triton attention.'
)
if cfg.get(
'attn_clip_qkv') and not cfg.attn_impl in ['flash', 'triton']:
raise NotImplementedError(
'QKV clipping only implemented with flash and triton attention.'
)

self.alibi = cfg.get('alibi', False)
Expand Down