
Commit

updt
vchiley committed Mar 1, 2023
1 parent eeaddd1 commit 212b692
Showing 1 changed file with 65 additions and 67 deletions.
132 changes: 65 additions & 67 deletions in examples/llm/tests/test_flash_triton_torch.py
@@ -4,7 +4,7 @@
 import pytest
 import torch
 from composer.utils import reproducibility
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf as om

 from examples.llm.src.models.layers.attention import (
     FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention)
@@ -13,6 +13,69 @@
 ATOL = 1e-2


+def test_flash_torch():
+    reproducibility.seed_all(7)
+
+    cfg = om.create({
+        'd_model': 256,
+        'n_heads': 2,
+        'attn_pdrop': 0,
+    })
+
+    n, s, f = 2, 16, cfg.d_model
+
+    fca = FlashCausalAttention(cfg).to('cuda')
+    tca = TorchCausalAttention(cfg).to('cuda')
+
+    def gen_tca_mask():
+        ms = TorchCausalAttention.mask_shape(cfg.n_heads, s, False)
+        attn_mask = torch.empty(*ms).to('cuda')
+        TorchCausalAttention.attn_mask_(attn_mask, cfg.n_heads, s)
+        return attn_mask
+
+    # clone weights
+    tca.mhsa.in_proj_weight.data = fca.mhsa.Wqkv.weight.data.clone().detach()
+    tca.mhsa.in_proj_bias.data = fca.mhsa.Wqkv.bias.data.clone().detach()
+    tca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone(
+    ).detach()
+    tca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone().detach()
+
+    key_padding_mask = torch.ones(n, s).to('cuda').bool()
+    x0 = torch.randn(n, s, f).to('cuda')
+    x1 = x0.clone().detach()
+    x0.requires_grad = True
+    x1.requires_grad = True
+
+    with torch.autocast(x0.device.type):
+        y0, _ = fca(x0, key_padding_mask, attn_mask=None)
+        y1, _ = tca(x1, key_padding_mask, attn_mask=gen_tca_mask())
+        y0 *= key_padding_mask.unsqueeze(-1)
+        y1 *= key_padding_mask.unsqueeze(-1)
+
+        loss0 = y0.sum()
+        loss1 = y1.sum()
+
+    loss0.backward()
+    loss1.backward()
+
+    assert y0.allclose(y1, rtol=RTOL, atol=ATOL)
+
+    assert tca.mhsa.out_proj.bias.grad.allclose(fca.mhsa.out_proj.bias.grad,
+                                                rtol=RTOL,
+                                                atol=ATOL)
+    assert tca.mhsa.out_proj.weight.grad.allclose(fca.mhsa.out_proj.weight.grad,
+                                                  rtol=RTOL,
+                                                  atol=ATOL)
+    assert tca.mhsa.in_proj_bias.grad.allclose(fca.mhsa.Wqkv.bias.grad,
+                                               rtol=RTOL,
+                                               atol=ATOL)
+    assert tca.mhsa.in_proj_weight.grad.allclose(fca.mhsa.Wqkv.weight.grad,
+                                                 rtol=RTOL,
+                                                 atol=ATOL)
+
+    assert x0.grad.allclose(x1.grad, rtol=RTOL, atol=ATOL)
+
+
 @pytest.mark.parametrize('attn_clip_qkv,attn_qk_ln', [
     (False, False),
     (False, True),
@@ -22,7 +85,7 @@
 def test_flash_triton(attn_clip_qkv, attn_qk_ln):
     reproducibility.seed_all(7)

-    cfg = OmegaConf.create({
+    cfg = om.create({
         'd_model': 256,
         'n_heads': 2,
         'attn_pdrop': 0,
@@ -112,68 +175,3 @@ def test_flash_triton(attn_clip_qkv, attn_qk_ln):
                                                 atol=ATOL)

     assert x0.grad.allclose(x1.grad, rtol=RTOL, atol=ATOL)
-
-
-def test_flash_torch():
-    reproducibility.seed_all(7)
-
-    cfg = OmegaConf.create({
-        'd_model': 256,
-        'n_heads': 2,
-        'attn_pdrop': 0,
-        'attn_clip_qkv': False,
-        'attn_qk_ln': False,
-    })
-
-    n, s, f = 2, 16, cfg.d_model
-
-    fca = FlashCausalAttention(cfg).to('cuda')
-    tca = TorchCausalAttention(cfg).to('cuda')
-
-    def gen_tca_mask():
-        ms = TorchCausalAttention.mask_shape(cfg.n_heads, s, False)
-        attn_mask = torch.empty(*ms).to('cuda')
-        TorchCausalAttention.attn_mask_(attn_mask, cfg.n_heads, s)
-        return attn_mask
-
-    # clone weights
-    tca.mhsa.in_proj_weight.data = fca.mhsa.Wqkv.weight.data.clone().detach()
-    tca.mhsa.in_proj_bias.data = fca.mhsa.Wqkv.bias.data.clone().detach()
-    tca.mhsa.out_proj.weight.data = fca.mhsa.out_proj.weight.data.clone(
-    ).detach()
-    tca.mhsa.out_proj.bias.data = fca.mhsa.out_proj.bias.data.clone().detach()
-
-    key_padding_mask = torch.ones(n, s).to('cuda').bool()
-    x0 = torch.randn(n, s, f).to('cuda')
-    x1 = x0.clone().detach()
-    x0.requires_grad = True
-    x1.requires_grad = True
-
-    with torch.autocast(x0.device.type):
-        y0, _ = fca(x0, key_padding_mask, attn_mask=None)
-        y1, _ = tca(x1, key_padding_mask, attn_mask=gen_tca_mask())
-        y0 *= key_padding_mask.unsqueeze(-1)
-        y1 *= key_padding_mask.unsqueeze(-1)
-
-        loss0 = y0.sum()
-        loss1 = y1.sum()
-
-    loss0.backward()
-    loss1.backward()
-
-    assert y0.allclose(y1, rtol=RTOL, atol=ATOL)
-
-    assert tca.mhsa.out_proj.bias.grad.allclose(fca.mhsa.out_proj.bias.grad,
-                                                rtol=RTOL,
-                                                atol=ATOL)
-    assert tca.mhsa.out_proj.weight.grad.allclose(fca.mhsa.out_proj.weight.grad,
-                                                  rtol=RTOL,
-                                                  atol=ATOL)
-    assert tca.mhsa.in_proj_bias.grad.allclose(fca.mhsa.Wqkv.bias.grad,
-                                               rtol=RTOL,
-                                               atol=ATOL)
-    assert tca.mhsa.in_proj_weight.grad.allclose(fca.mhsa.Wqkv.weight.grad,
-                                                 rtol=RTOL,
-                                                 atol=ATOL)
-
-    assert x0.grad.allclose(x1.grad, rtol=RTOL, atol=ATOL)
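For context on the weight-cloning block in `test_flash_torch`: assuming `TorchCausalAttention` wraps `torch.nn.MultiheadAttention` (as its `in_proj_weight` / `out_proj` attributes suggest), the query/key/value projections are fused into a single `in_proj_weight` of shape `(3 * d_model, d_model)`, which is why the test can copy the flash-attention module's fused `Wqkv` parameters across directly. A minimal sketch, not part of this commit, illustrating that parameter layout:

```python
# Illustrative only; uses the same d_model / n_heads as the test cfg.
import torch

d_model, n_heads = 256, 2
mha = torch.nn.MultiheadAttention(d_model, n_heads, batch_first=True)

# q, k, v projections are stacked into one fused parameter...
assert mha.in_proj_weight.shape == (3 * d_model, d_model)
assert mha.in_proj_bias.shape == (3 * d_model,)
# ...while the output projection stays a plain d_model x d_model linear.
assert mha.out_proj.weight.shape == (d_model, d_model)
```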
