diff --git a/examples/llm/src/models/layers/attention.py b/examples/llm/src/models/layers/attention.py
index 6f86117ef..a21136139 100644
--- a/examples/llm/src/models/layers/attention.py
+++ b/examples/llm/src/models/layers/attention.py
@@ -120,17 +120,18 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
         except ImportError as e:
             raise e
 
+        self.clip_qkv = cfg.get('attn_clip_qkv')
         self.attn_qk_ln = cfg.get('attn_qk_ln')
         self.d_model = cfg.d_model
         self.n_heads = cfg.n_heads
 
-        if self.attn_qk_ln:
+        if self.attn_qk_ln or self.clip_qkv:
             self.W_qkv = nn.Linear(self.d_model,
                                    3 * self.d_model,
                                    bias=True,
                                    device=device)
-            self.causal_attn = FlashAttention(attention_dropout=cfg.attn_pdrop,
-                                              device=device)
+            self.inner_attn = FlashAttention(attention_dropout=cfg.attn_pdrop,
+                                             device=device)
             self.out_proj = nn.Linear(self.d_model,
                                       self.d_model,
                                       bias=True,
@@ -138,8 +139,9 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
 
             self.out_proj._is_residual = True  # type: ignore
 
-            self.q_ln = nn.LayerNorm(self.d_model, device=device)
-            self.k_ln = nn.LayerNorm(self.d_model, device=device)
+            if self.attn_qk_ln:
+                self.q_ln = nn.LayerNorm(self.d_model, device=device)
+                self.k_ln = nn.LayerNorm(self.d_model, device=device)
         else:
             self.mhsa = FlashMHA(
                 embed_dim=cfg.d_model,
@@ -154,15 +156,18 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
 
     def forward(self, x, key_padding_mask, attn_mask=None):
         assert attn_mask is None
 
-        if self.attn_qk_ln:
-            qkv = self.W_qkv(x)
-            # Applying layernorm to qk
-            dtype = qkv.dtype
-            q, k, v = qkv.split(self.d_model, dim=-1)
-            q = self.q_ln(q).to(dtype)
-            k = self.k_ln(k).to(dtype)
-            qkv = torch.cat([q, k, v], dim=-1)
+        if self.attn_qk_ln or self.clip_qkv:
+            qkv = self.W_qkv(x)
+            if self.clip_qkv:
+                qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+            if self.attn_qk_ln:
+                # Applying layernorm to qk
+                dtype = qkv.dtype
+                q, k, v = qkv.split(self.d_model, dim=-1)
+                q = self.q_ln(q).to(dtype)
+                k = self.k_ln(k).to(dtype)
+                qkv = torch.cat([q, k, v], dim=-1)
 
             # attention
             qkv = rearrange(qkv,
@@ -170,7 +175,7 @@ def forward(self, x, key_padding_mask, attn_mask=None):
                             three=3,
                             h=self.n_heads)
 
-            context, attn_weights = self.causal_attn(
+            context, attn_weights = self.inner_attn(
                 qkv,
                 key_padding_mask=key_padding_mask,
                 causal=True,
@@ -213,22 +218,48 @@ class TritonFlashCausalAttention(nn.Module):
     def __init__(self, cfg: DictConfig, device: Optional[str] = None):
         super().__init__()
         try:
-            from examples.llm.src.models.layers.flash_attention import \
-                FlashMHA  # type: ignore
+            from examples.llm.src.models.layers.flash_attention import (  # type: ignore
+                FlashAttention, FlashMHA)
         except ImportError as e:
             raise e
 
         assert cfg.attn_pdrop == 0, 'triton kernel does not support attn_dropout'
 
-        self.mhsa = FlashMHA(
-            embed_dim=cfg.d_model,
-            num_heads=cfg.n_heads,
-            bias=True,
-            batch_first=True,
-            causal=True,
-            device=device,
-        )
-        self.mhsa.out_proj._is_residual = True  # type: ignore
+        self.clip_qkv = cfg.get('attn_clip_qkv')
+        self.attn_qk_ln = cfg.get('attn_qk_ln')
+        self.d_model = cfg.d_model
+        self.n_heads = cfg.n_heads
+
+        if self.attn_qk_ln or self.clip_qkv:
+            self.Wqkv = nn.Linear(self.d_model,
+                                  3 * self.d_model,
+                                  bias=True,
+                                  device=device)
+            self.inner_attn = FlashAttention(
+                num_heads=cfg.n_heads,
+                softmax_scale=cfg.get('softmax_scale'),
+                device=device)
+            self.out_proj = nn.Linear(self.d_model,
+                                      self.d_model,
+                                      bias=True,
+                                      device=device)
+
+            self.out_proj._is_residual = True  # type: ignore
+
+            if self.attn_qk_ln:
+                self.q_ln = nn.LayerNorm(self.d_model, device=device)
+                self.k_ln = nn.LayerNorm(self.d_model, device=device)
+        else:
+            self.mhsa = FlashMHA(
+                embed_dim=cfg.d_model,
+                num_heads=cfg.n_heads,
+                bias=True,
+                batch_first=True,
+                causal=True,
+                softmax_scale=cfg.get('softmax_scale'),
+                device=device,
+            )
+            self.mhsa.out_proj._is_residual = True  # type: ignore
 
         warnings.warn(
             'While `attn_impl: triton` can be faster than `attn_impl: flash` '
@@ -237,11 +268,34 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
             'using `attn_impl: flash`.')
 
     def forward(self, x, key_padding_mask=None, attn_mask=None):
-        assert key_padding_mask is None
-        return self.mhsa(x,
-                         key_padding_mask=None,
-                         attn_mask=attn_mask,
-                         need_weights=False)
+        if self.attn_qk_ln or self.clip_qkv:
+            qkv = self.Wqkv(x)
+            if self.clip_qkv:
+                qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+            if self.attn_qk_ln:
+                # Applying layernorm to qk
+                dtype = qkv.dtype
+                q, k, v = qkv.split(self.d_model, dim=-1)
+                q = self.q_ln(q).to(dtype)
+                k = self.k_ln(k).to(dtype)
+                qkv = torch.cat([q, k, v], dim=-1)
+
+            # attention
+            context, attn_weights = self.inner_attn(
+                qkv,
+                key_padding_mask=key_padding_mask,
+                attn_mask=attn_mask,
+                is_causal=True,
+                need_weights=False,
+                average_attn_weights=False)
+
+            return self.out_proj(context), attn_weights
+
+        else:
+            return self.mhsa(x,
+                             key_padding_mask=None,
+                             attn_mask=attn_mask,
+                             need_weights=False)
 
     @staticmethod
     def mask_shape(n_heads, seq_len, alibi):
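For reference, the QKV pre-processing order that FlashCausalAttention and TritonFlashCausalAttention now share reduces to a few lines of plain PyTorch: project with the fused QKV linear, optionally clamp to +/- clip_qkv, then optionally LayerNorm the queries and keys before re-fusing. The sketch below is illustrative only (the d_model, clip value, and toggles are made up) and stops before any attention kernel is called.

import torch
import torch.nn as nn

d_model, clip_qkv, attn_qk_ln = 256, 6.0, True      # illustrative settings
Wqkv = nn.Linear(d_model, 3 * d_model, bias=True)
q_ln, k_ln = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

x = torch.randn(2, 16, d_model)                     # (batch, seq, d_model)
qkv = Wqkv(x)                                       # fused projection
if clip_qkv:
    qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)    # elementwise clipping
if attn_qk_ln:
    dtype = qkv.dtype
    q, k, v = qkv.split(d_model, dim=-1)
    q = q_ln(q).to(dtype)                           # LayerNorm on queries
    k = k_ln(k).to(dtype)                           # LayerNorm on keys; values pass through
    qkv = torch.cat([q, k, v], dim=-1)              # re-fuse for the attention kernel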
diff --git a/examples/llm/src/models/layers/flash_attention.py b/examples/llm/src/models/layers/flash_attention.py
index 1a0931e6f..a3b010d8c 100644
--- a/examples/llm/src/models/layers/flash_attention.py
+++ b/examples/llm/src/models/layers/flash_attention.py
@@ -100,6 +100,7 @@ def __init__(self,
                  bias=True,
                  batch_first=True,
                  causal=False,
+                 softmax_scale=None,
                  device=None,
                  dtype=None,
                  **kwargs) -> None:
@@ -119,7 +120,7 @@ def __init__(self,
                               bias=bias,
                               **factory_kwargs)
         self.inner_attn = FlashAttention(num_heads=num_heads,
-                                         softmax_scale=None,
+                                         softmax_scale=softmax_scale,
                                          **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim,
                                   embed_dim,
diff --git a/examples/llm/src/models/mosaic_gpt.py b/examples/llm/src/models/mosaic_gpt.py
index cd64ebbf7..2fb21a225 100644
--- a/examples/llm/src/models/mosaic_gpt.py
+++ b/examples/llm/src/models/mosaic_gpt.py
@@ -37,9 +37,15 @@ def __init__(self, cfg: DictConfig):
         else:
             raise ValueError(f'Unknown attn_impl={cfg.attn_impl}')
 
-        if cfg.get('attn_qk_ln') and cfg.attn_impl != 'flash':
+        if cfg.get('attn_qk_ln') and cfg.attn_impl not in ['flash', 'triton']:
             raise NotImplementedError(
-                'LayerNorm over queries and keys in attention is only implemented with flash attention.'
+                'LayerNorm over queries and keys in attention is only implemented with flash and triton attention.'
+            )
+        if cfg.get('attn_clip_qkv') and cfg.attn_impl not in [
+                'flash', 'triton'
+        ]:
+            raise NotImplementedError(
+                'QKV clipping only implemented with flash and triton attention.'
             )
 
         self.alibi = cfg.get('alibi', False)
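Taken together, the mosaic_gpt.py validation and the softmax_scale pass-through define the new config surface. A hedged illustration (all values hypothetical): attn_qk_ln and attn_clip_qkv are only accepted when attn_impl is 'flash' or 'triton', and leaving softmax_scale unset is assumed here to fall back to the kernel's usual 1/sqrt(head_dim) default, which this diff does not change.

import math
from omegaconf import OmegaConf as om

cfg = om.create({
    'd_model': 256,
    'n_heads': 2,
    'attn_impl': 'triton',
    'attn_qk_ln': True,      # LayerNorm over queries and keys
    'attn_clip_qkv': 6.0,    # clamp fused QKV activations to [-6.0, 6.0]
    'softmax_scale': None,   # None -> kernel default (assumed 1/sqrt(head_dim))
})

assert cfg.attn_impl in ['flash', 'triton']  # otherwise MosaicGPT raises NotImplementedError
scale = cfg.softmax_scale or 1.0 / math.sqrt(cfg.d_model // cfg.n_heads)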
diff --git a/examples/llm/tests/test_flash_triton_torch.py b/examples/llm/tests/test_flash_triton_torch.py
new file mode 100644
index 000000000..eb7c44098
--- /dev/null
+++ b/examples/llm/tests/test_flash_triton_torch.py
@@ -0,0 +1,161 @@
+# Copyright 2022 MosaicML Examples authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from composer.utils import reproducibility
+from omegaconf import OmegaConf as om
+
+
+def allclose_helper(t0, t1, rtol=1e-2, atol=1e-2):
+    return torch.allclose(t0, t1, rtol=rtol, atol=atol)
+
+
+@pytest.mark.gpu
+def test_flash_torch(device='cuda'):
+    from examples.llm.src.models.layers.attention import (  # type: ignore
+        FlashCausalAttention, TorchCausalAttention)
+
+    reproducibility.seed_all(7)
+
+    cfg = om.create({
+        'd_model': 256,
+        'n_heads': 2,
+        'attn_pdrop': 0,
+    })
+
+    n, s, f = 2, 16, cfg.d_model
+
+    fca = FlashCausalAttention(cfg).to(device)
+    tca = TorchCausalAttention(cfg).to(device)
+
+    def gen_tca_mask():
+        ms = TorchCausalAttention.mask_shape(cfg.n_heads, s, False)
+        attn_mask = torch.empty(*ms).to(device)
+        TorchCausalAttention.attn_mask_(attn_mask, cfg.n_heads, s)
+        return attn_mask
+
+    # clone weights
+    tca.mhsa.in_proj_weight.data = fca.mhsa.Wqkv.weight.data.clone().detach()
+    tca.mhsa.in_proj_bias.data = fca.mhsa.Wqkv.bias.data.clone().detach()
+    tca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone(
+    ).detach()
+    tca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone().detach()
+
+    key_padding_mask = torch.ones(n, s).to(device).bool()
+    x0 = torch.randn(n, s, f).to(device)
+    x1 = x0.clone().detach()
+    x0.requires_grad = True
+    x1.requires_grad = True
+
+    with torch.autocast(x0.device.type):
+        y0, _ = fca(x0, key_padding_mask, attn_mask=None)
+        y1, _ = tca(x1, key_padding_mask, attn_mask=gen_tca_mask())
+        y0 *= key_padding_mask.unsqueeze(-1)
+        y1 *= key_padding_mask.unsqueeze(-1)
+
+        loss0 = y0.sum()
+        loss1 = y1.sum()
+
+    loss0.backward()
+    loss1.backward()
+
+    assert allclose_helper(y0, y1)
+
+    assert allclose_helper(tca.mhsa.out_proj.bias.grad,
+                           fca.mhsa.out_proj.bias.grad)
+    assert allclose_helper(tca.mhsa.out_proj.weight.grad,
+                           fca.mhsa.out_proj.weight.grad)
+    assert allclose_helper(tca.mhsa.in_proj_bias.grad, fca.mhsa.Wqkv.bias.grad)
+    assert allclose_helper(tca.mhsa.in_proj_weight.grad,
+                           fca.mhsa.Wqkv.weight.grad)
+
+    assert allclose_helper(x0.grad, x1.grad)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('attn_clip_qkv,attn_qk_ln', [
+    (False, False),
+    (False, True),
+    (True, False),
+    (True, True),
+])
+def test_flash_triton(attn_clip_qkv, attn_qk_ln, device='cuda'):
+    from examples.llm.src.models.layers.attention import (  # type: ignore
+        FlashCausalAttention, TritonFlashCausalAttention)
+
+    reproducibility.seed_all(7)
+
+    cfg = om.create({
+        'd_model': 256,
+        'n_heads': 2,
+        'attn_pdrop': 0,
+        'attn_clip_qkv': attn_clip_qkv,
+        'attn_qk_ln': attn_qk_ln,
+    })
+
+    n, s, f = 2, 16, cfg.d_model
+
+    fca = FlashCausalAttention(cfg).to(device)
+    tfca = TritonFlashCausalAttention(cfg).to(device)
+    # clone weights
+    if cfg.attn_qk_ln or cfg.attn_clip_qkv:
+        tfca.Wqkv.weight.data = fca.W_qkv.weight.data.clone().detach()
+        tfca.Wqkv.bias.data = fca.W_qkv.bias.data.clone().detach()
+        tfca.out_proj.weight.data = fca.out_proj.weight.data.clone().detach()
+        tfca.out_proj.bias.data = fca.out_proj.bias.data.clone().detach()
+        if cfg.attn_qk_ln:
+            tfca.q_ln.weight.data = fca.q_ln.weight.data.clone().detach()
+            tfca.q_ln.bias.data = fca.q_ln.bias.data.clone().detach()
+            tfca.k_ln.weight.data = fca.k_ln.weight.data.clone().detach()
+            tfca.k_ln.bias.data = fca.k_ln.bias.data.clone().detach()
+    else:
+        tfca.mhsa.Wqkv.weight.data = fca.mhsa.Wqkv.weight.data.clone().detach()
+        tfca.mhsa.Wqkv.bias.data = fca.mhsa.Wqkv.bias.data.clone().detach()
+        tfca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone(
+        ).detach()
+        tfca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone(
+        ).detach()
+
+    key_padding_mask = torch.ones(n, s).to(device)
+    x0 = torch.randn(n, s, f).to(device)
+    x1 = x0.clone().detach()
+    x0.requires_grad = True
+    x1.requires_grad = True
+
+    with torch.autocast(x0.device.type):
+        y0, _ = fca(x0, key_padding_mask, attn_mask=None)
+        y1, _ = tfca(x1, key_padding_mask, attn_mask=None)
+        y0 *= key_padding_mask.unsqueeze(-1)
+        y1 *= key_padding_mask.unsqueeze(-1)
+
+        loss0 = y0.sum()
+        loss1 = y1.sum()
+
+    loss0.backward()
+    loss1.backward()
+
+    assert allclose_helper(y0, y1)
+
+    if cfg.attn_qk_ln or cfg.attn_clip_qkv:
+        assert allclose_helper(tfca.out_proj.bias.grad, fca.out_proj.bias.grad)
+        assert allclose_helper(tfca.out_proj.weight.grad,
+                               fca.out_proj.weight.grad)
+        if cfg.attn_qk_ln:
+            assert allclose_helper(tfca.q_ln.bias.grad, fca.q_ln.bias.grad)
+            assert allclose_helper(tfca.q_ln.weight.grad, fca.q_ln.weight.grad)
+            assert allclose_helper(tfca.k_ln.bias.grad, fca.k_ln.bias.grad)
+            assert allclose_helper(tfca.k_ln.weight.grad, fca.k_ln.weight.grad)
+        assert allclose_helper(tfca.Wqkv.bias.grad, fca.W_qkv.bias.grad)
+        assert allclose_helper(tfca.Wqkv.weight.grad, fca.W_qkv.weight.grad)
+    else:
+        assert allclose_helper(tfca.mhsa.out_proj.bias.grad,
+                               fca.mhsa.out_proj.bias.grad)
+        assert allclose_helper(tfca.mhsa.out_proj.weight.grad,
+                               fca.mhsa.out_proj.weight.grad)
+        assert allclose_helper(tfca.mhsa.Wqkv.bias.grad,
+                               fca.mhsa.Wqkv.bias.grad)
+        assert allclose_helper(tfca.mhsa.Wqkv.weight.grad,
+                               fca.mhsa.Wqkv.weight.grad)
+
+    assert allclose_helper(x0.grad, x1.grad)
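The new equivalence tests are gated behind the gpu marker. If a programmatic entry point is wanted, something like the following (a hypothetical convenience runner, not part of this PR) would select them through pytest's public API:

import pytest

if __name__ == '__main__':
    # Run only the gpu-marked equivalence tests in the new file.
    raise SystemExit(
        pytest.main(['-m', 'gpu', 'examples/llm/tests/test_flash_triton_torch.py']))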
diff --git a/examples/llm/tests/test_model.py b/examples/llm/tests/test_model.py
index 0ef59c779..248b2a93b 100644
--- a/examples/llm/tests/test_model.py
+++ b/examples/llm/tests/test_model.py
@@ -186,7 +186,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm, batch_size=2):
 
     device = 'cpu'
     neo_cfg.device = device
-    neo_cfg.max_seq_len = 1024
+    neo_cfg.max_seq_len = 256
 
     if prefixlm:
         neo_cfg.model.name = 'hf_prefix_lm'