ClipQKV (#197)
* add clip to attn and add ln_qk to triton kernel

* Update examples/llm/src/models/mosaic_gpt.py

Co-authored-by: Abhi Venigalla <77638579+abhi-mosaic@users.noreply.github.com>

* apply Abhi's review suggestions

* add attn tests

* address dk's PR comments

---------

Co-authored-by: Abhi Venigalla <77638579+abhi-mosaic@users.noreply.github.com>
vchiley and abhi-mosaic authored Mar 1, 2023
1 parent 27917b5 commit e63c50c
Showing 5 changed files with 256 additions and 34 deletions.
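Before the file-by-file diff, here is a minimal usage sketch of the two options this commit adds, modeled on the new test file. The float clip value (6.0) and the CUDA device are illustrative assumptions, not values used in the PR; the tests pass booleans for `attn_clip_qkv`, and both attention classes require a GPU plus the flash-attn / triton dependencies.

from omegaconf import OmegaConf as om

from examples.llm.src.models.layers.attention import \
    TritonFlashCausalAttention

# Sketch config based on the new test file; 6.0 and 'cuda' are assumptions.
cfg = om.create({
    'd_model': 256,
    'n_heads': 2,
    'attn_pdrop': 0,          # the triton kernel asserts attn_pdrop == 0
    'attn_clip_qkv': 6.0,     # clamp the fused QKV projection to [-6.0, 6.0]
    'attn_qk_ln': True,       # LayerNorm over the query and key projections
})
attn = TritonFlashCausalAttention(cfg, device='cuda')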
114 changes: 84 additions & 30 deletions examples/llm/src/models/layers/attention.py
@@ -120,26 +120,28 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
except ImportError as e:
raise e

self.clip_qkv = cfg.get('attn_clip_qkv')
self.attn_qk_ln = cfg.get('attn_qk_ln')
self.d_model = cfg.d_model
self.n_heads = cfg.n_heads

if self.attn_qk_ln:
if self.attn_qk_ln or self.clip_qkv:
self.W_qkv = nn.Linear(self.d_model,
3 * self.d_model,
bias=True,
device=device)
self.causal_attn = FlashAttention(attention_dropout=cfg.attn_pdrop,
device=device)
self.inner_attn = FlashAttention(attention_dropout=cfg.attn_pdrop,
device=device)
self.out_proj = nn.Linear(self.d_model,
self.d_model,
bias=True,
device=device)

self.out_proj._is_residual = True # type: ignore

self.q_ln = nn.LayerNorm(self.d_model, device=device)
self.k_ln = nn.LayerNorm(self.d_model, device=device)
if self.attn_qk_ln:
self.q_ln = nn.LayerNorm(self.d_model, device=device)
self.k_ln = nn.LayerNorm(self.d_model, device=device)
else:
self.mhsa = FlashMHA(
embed_dim=cfg.d_model,
@@ -154,23 +156,26 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):

def forward(self, x, key_padding_mask, attn_mask=None):
assert attn_mask is None
if self.attn_qk_ln:
qkv = self.W_qkv(x)

# Applying layernorm to qk
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype)
k = self.k_ln(k).to(dtype)
qkv = torch.cat([q, k, v], dim=-1)
if self.attn_qk_ln or self.clip_qkv:
qkv = self.W_qkv(x)
if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
if self.attn_qk_ln:
# Applying layernorm to qk
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype)
k = self.k_ln(k).to(dtype)
qkv = torch.cat([q, k, v], dim=-1)

# attention
qkv = rearrange(qkv,
'b s (three h d) -> b s three h d',
three=3,
h=self.n_heads)

context, attn_weights = self.causal_attn(
context, attn_weights = self.inner_attn(
qkv,
key_padding_mask=key_padding_mask,
causal=True,
@@ -213,22 +218,48 @@ class TritonFlashCausalAttention(nn.Module):
def __init__(self, cfg: DictConfig, device: Optional[str] = None):
super().__init__()
try:
from examples.llm.src.models.layers.flash_attention import \
FlashMHA # type: ignore
from examples.llm.src.models.layers.flash_attention import ( # type: ignore
FlashAttention, FlashMHA)
except ImportError as e:
raise e

assert cfg.attn_pdrop == 0, 'triton kernel does not support attn_dropout'

self.mhsa = FlashMHA(
embed_dim=cfg.d_model,
num_heads=cfg.n_heads,
bias=True,
batch_first=True,
causal=True,
device=device,
)
self.mhsa.out_proj._is_residual = True # type: ignore
self.clip_qkv = cfg.get('attn_clip_qkv')
self.attn_qk_ln = cfg.get('attn_qk_ln')
self.d_model = cfg.d_model
self.n_heads = cfg.n_heads

if self.attn_qk_ln or self.clip_qkv:
self.Wqkv = nn.Linear(self.d_model,
3 * self.d_model,
bias=True,
device=device)
self.inner_attn = FlashAttention(
num_heads=cfg.n_heads,
softmax_scale=cfg.get('softmax_scale'),
device=device)
self.out_proj = nn.Linear(self.d_model,
self.d_model,
bias=True,
device=device)

self.out_proj._is_residual = True # type: ignore

if self.attn_qk_ln:
self.q_ln = nn.LayerNorm(self.d_model, device=device)
self.k_ln = nn.LayerNorm(self.d_model, device=device)
else:
self.mhsa = FlashMHA(
embed_dim=cfg.d_model,
num_heads=cfg.n_heads,
bias=True,
batch_first=True,
causal=True,
softmax_scale=cfg.get('softmax_scale'),
device=device,
)
self.mhsa.out_proj._is_residual = True # type: ignore

warnings.warn(
'While `attn_impl: triton` can be faster than `attn_impl: flash` '
@@ -237,11 +268,34 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
'using `attn_impl: flash`.')

def forward(self, x, key_padding_mask=None, attn_mask=None):
assert key_padding_mask is None
return self.mhsa(x,
key_padding_mask=None,
attn_mask=attn_mask,
need_weights=False)
if self.attn_qk_ln or self.clip_qkv:
qkv = self.Wqkv(x)
if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
if self.attn_qk_ln:
# Applying layernorm to qk
dtype = qkv.dtype
q, k, v = qkv.split(self.d_model, dim=-1)
q = self.q_ln(q).to(dtype)
k = self.k_ln(k).to(dtype)
qkv = torch.cat([q, k, v], dim=-1)

# attention
context, attn_weights = self.inner_attn(
qkv,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
is_causal=True,
need_weights=False,
average_attn_weights=False)

return self.out_proj(context), attn_weights

else:
return self.mhsa(x,
key_padding_mask=None,
attn_mask=attn_mask,
need_weights=False)

@staticmethod
def mask_shape(n_heads, seq_len, alibi):
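Taken together, the changes to both attention classes add the same pre-attention path: project, optionally clamp, optionally LayerNorm Q and K, then re-fuse for the kernel. A minimal standalone sketch of that ordering follows; the tensor sizes and the 6.0 clip value are illustrative, and the real modules reuse the Linear/LayerNorm submodules built in __init__ and call the in-place clamp_.

import torch
import torch.nn as nn

d_model, clip_qkv = 8, 6.0
Wqkv = nn.Linear(d_model, 3 * d_model)
q_ln, k_ln = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

x = torch.randn(2, 4, d_model)                 # (batch, seq, d_model)
qkv = Wqkv(x)                                  # fused QKV projection
qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)   # ClipQKV: bound activations first
q, k, v = qkv.split(d_model, dim=-1)
q, k = q_ln(q), k_ln(k)                        # attn_qk_ln: normalize Q and K only
qkv = torch.cat([q, k, v], dim=-1)             # re-fuse for the attention kernel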
3 changes: 2 additions & 1 deletion examples/llm/src/models/layers/flash_attention.py
@@ -100,6 +100,7 @@ def __init__(self,
bias=True,
batch_first=True,
causal=False,
softmax_scale=None,
device=None,
dtype=None,
**kwargs) -> None:
@@ -119,7 +120,7 @@ def __init__(self,
bias=bias,
**factory_kwargs)
self.inner_attn = FlashAttention(num_heads=num_heads,
softmax_scale=None,
softmax_scale=softmax_scale,
**factory_kwargs)
self.out_proj = nn.Linear(embed_dim,
embed_dim,
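flash_attention.py now forwards `softmax_scale` instead of hard-coding `None`. As an assumption about the upstream flash-attn default (not shown in this diff), a `None` scale usually falls back to 1/sqrt(head_dim); a small sketch of that fallback:

import math

d_model, n_heads = 256, 2
head_dim = d_model // n_heads
softmax_scale = None                      # value plumbed through FlashMHA
# Assumed upstream fallback when softmax_scale is None:
scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
print(scale)                              # ~0.0884 for head_dim=128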
10 changes: 8 additions & 2 deletions examples/llm/src/models/mosaic_gpt.py
@@ -37,9 +37,15 @@ def __init__(self, cfg: DictConfig):
else:
raise ValueError(f'Unknown attn_impl={cfg.attn_impl}')

if cfg.get('attn_qk_ln') and cfg.attn_impl != 'flash':
if cfg.get('attn_qk_ln') and cfg.attn_impl not in ['flash', 'triton']:
raise NotImplementedError(
'LayerNorm over queries and keys in attention is only implemented with flash attention.'
'LayerNorm over queries and keys in attention is only implemented with flash and triton attention.'
)
if cfg.get('attn_clip_qkv') and cfg.attn_impl not in [
'flash', 'triton'
]:
raise NotImplementedError(
'QKV clipping only implemented with flash and triton attention.'
)

self.alibi = cfg.get('alibi', False)
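The two checks above gate the new options on the attention implementation. A minimal sketch of the resulting behavior, with the config abbreviated to just the fields the checks read (a full MosaicGPT config carries many more required fields):

from omegaconf import OmegaConf as om

cfg = om.create({'attn_impl': 'torch', 'attn_clip_qkv': 6.0, 'attn_qk_ln': False})

# Mirrors the validation added to mosaic_gpt.py: both new options require
# attn_impl to be 'flash' or 'triton'.
if cfg.get('attn_qk_ln') and cfg.attn_impl not in ['flash', 'triton']:
    raise NotImplementedError(
        'LayerNorm over queries and keys in attention is only implemented with flash and triton attention.'
    )
if cfg.get('attn_clip_qkv') and cfg.attn_impl not in ['flash', 'triton']:
    # This branch fires for the sketch config above.
    raise NotImplementedError(
        'QKV clipping only implemented with flash and triton attention.'
    )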
161 changes: 161 additions & 0 deletions examples/llm/tests/test_flash_triton_torch.py
@@ -0,0 +1,161 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from composer.utils import reproducibility
from omegaconf import OmegaConf as om


def allclose_helper(t0, t1, rtol=1e-2, atol=1e-2):
return torch.allclose(t0, t1, rtol=rtol, atol=atol)


@pytest.mark.gpu
def test_flash_torch(device='cuda'):
from examples.llm.src.models.layers.attention import ( # type: ignore
FlashCausalAttention, TorchCausalAttention)

reproducibility.seed_all(7)

cfg = om.create({
'd_model': 256,
'n_heads': 2,
'attn_pdrop': 0,
})

n, s, f = 2, 16, cfg.d_model

fca = FlashCausalAttention(cfg).to(device)
tca = TorchCausalAttention(cfg).to(device)

def gen_tca_mask():
ms = TorchCausalAttention.mask_shape(cfg.n_heads, s, False)
attn_mask = torch.empty(*ms).to(device)
TorchCausalAttention.attn_mask_(attn_mask, cfg.n_heads, s)
return attn_mask

# clone weights
tca.mhsa.in_proj_weight.data = fca.mhsa.Wqkv.weight.data.clone().detach()
tca.mhsa.in_proj_bias.data = fca.mhsa.Wqkv.bias.data.clone().detach()
tca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone(
).detach()
tca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone().detach()

key_padding_mask = torch.ones(n, s).to(device).bool()
x0 = torch.randn(n, s, f).to(device)
x1 = x0.clone().detach()
x0.requires_grad = True
x1.requires_grad = True

with torch.autocast(x0.device.type):
y0, _ = fca(x0, key_padding_mask, attn_mask=None)
y1, _ = tca(x1, key_padding_mask, attn_mask=gen_tca_mask())
y0 *= key_padding_mask.unsqueeze(-1)
y1 *= key_padding_mask.unsqueeze(-1)

loss0 = y0.sum()
loss1 = y1.sum()

loss0.backward()
loss1.backward()

assert allclose_helper(y0, y1)

assert allclose_helper(tca.mhsa.out_proj.bias.grad,
fca.mhsa.out_proj.bias.grad)
assert allclose_helper(tca.mhsa.out_proj.weight.grad,
fca.mhsa.out_proj.weight.grad)
assert allclose_helper(tca.mhsa.in_proj_bias.grad, fca.mhsa.Wqkv.bias.grad)
assert allclose_helper(tca.mhsa.in_proj_weight.grad,
fca.mhsa.Wqkv.weight.grad)

assert allclose_helper(x0.grad, x1.grad)


@pytest.mark.gpu
@pytest.mark.parametrize('attn_clip_qkv,attn_qk_ln', [
(False, False),
(False, True),
(True, False),
(True, True),
])
def test_flash_triton(attn_clip_qkv, attn_qk_ln, device='cuda'):
from examples.llm.src.models.layers.attention import ( # type: ignore
FlashCausalAttention, TritonFlashCausalAttention)

reproducibility.seed_all(7)

cfg = om.create({
'd_model': 256,
'n_heads': 2,
'attn_pdrop': 0,
'attn_clip_qkv': attn_clip_qkv,
'attn_qk_ln': attn_qk_ln,
})

n, s, f = 2, 16, cfg.d_model

fca = FlashCausalAttention(cfg).to(device)
tfca = TritonFlashCausalAttention(cfg).to(device)
# clone weights
if cfg.attn_qk_ln or cfg.attn_clip_qkv:
tfca.Wqkv.weight.data = fca.W_qkv.weight.data.clone().detach()
tfca.Wqkv.bias.data = fca.W_qkv.bias.data.clone().detach()
tfca.out_proj.weight.data = fca.out_proj.weight.data.clone().detach()
tfca.out_proj.bias.data = fca.out_proj.bias.data.clone().detach()
if cfg.attn_qk_ln:
tfca.q_ln.weight.data = fca.q_ln.weight.data.clone().detach()
tfca.q_ln.bias.data = fca.q_ln.bias.data.clone().detach()
tfca.k_ln.weight.data = fca.k_ln.weight.data.clone().detach()
tfca.k_ln.bias.data = fca.k_ln.bias.data.clone().detach()
else:
tfca.mhsa.Wqkv.weight.data = fca.mhsa.Wqkv.weight.data.clone().detach()
tfca.mhsa.Wqkv.bias.data = fca.mhsa.Wqkv.bias.data.clone().detach()
tfca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone(
).detach()
tfca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone(
).detach()

key_padding_mask = torch.ones(n, s).to(device)
x0 = torch.randn(n, s, f).to(device)
x1 = x0.clone().detach()
x0.requires_grad = True
x1.requires_grad = True

with torch.autocast(x0.device.type):
y0, _ = fca(x0, key_padding_mask, attn_mask=None)
y1, _ = tfca(x1, key_padding_mask, attn_mask=None)
y0 *= key_padding_mask.unsqueeze(-1)
y1 *= key_padding_mask.unsqueeze(-1)

loss0 = y0.sum()
loss1 = y1.sum()

loss0.backward()
loss1.backward()

assert allclose_helper(y0, y1)

if cfg.attn_qk_ln or cfg.attn_clip_qkv:
assert allclose_helper(tfca.out_proj.bias.grad, fca.out_proj.bias.grad)
assert allclose_helper(tfca.out_proj.weight.grad,
fca.out_proj.weight.grad)
if cfg.attn_qk_ln:
assert allclose_helper(tfca.q_ln.bias.grad, fca.q_ln.bias.grad)
assert allclose_helper(tfca.q_ln.weight.grad, fca.q_ln.weight.grad)
assert allclose_helper(tfca.k_ln.bias.grad, fca.k_ln.bias.grad)
assert allclose_helper(tfca.k_ln.weight.grad, fca.k_ln.weight.grad)
assert allclose_helper(tfca.Wqkv.bias.grad, fca.W_qkv.bias.grad)
assert allclose_helper(tfca.Wqkv.weight.grad, fca.W_qkv.weight.grad)
else:
assert allclose_helper(tfca.mhsa.out_proj.bias.grad,
fca.mhsa.out_proj.bias.grad)
assert allclose_helper(tfca.mhsa.out_proj.weight.grad,
fca.mhsa.out_proj.weight.grad)
assert allclose_helper(tfca.mhsa.Wqkv.bias.grad,
fca.mhsa.Wqkv.bias.grad)
assert allclose_helper(tfca.mhsa.Wqkv.weight.grad,
fca.mhsa.Wqkv.weight.grad)

assert allclose_helper(x0.grad, x1.grad)
2 changes: 1 addition & 1 deletion examples/llm/tests/test_model.py
@@ -186,7 +186,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm, batch_size=2):

device = 'cpu'
neo_cfg.device = device
neo_cfg.max_seq_len = 1024
neo_cfg.max_seq_len = 256

if prefixlm:
neo_cfg.model.name = 'hf_prefix_lm'
