
Commit

make test smaller
vchiley committed Mar 1, 2023
1 parent 8d2280a commit 2e2d298
Showing 2 changed files with 9 additions and 6 deletions.
13 changes: 8 additions & 5 deletions examples/llm/tests/test_flash_triton_torch.py
@@ -6,16 +6,16 @@
 from composer.utils import reproducibility
 from omegaconf import OmegaConf as om
 
-from examples.llm.src.models.layers.attention import ( # type: ignore
-    FlashCausalAttention, TorchCausalAttention, TritonFlashCausalAttention)
-
 
 def allclose_helper(t0, t1, rtol=1e-2, atol=1e-2):
     return torch.allclose(t0, t1, rtol=rtol, atol=atol)
 
 
-# @pytest.mark.gpu
+@pytest.mark.gpu
 def test_flash_torch(device='cuda'):
+    from examples.llm.src.models.layers.attention import ( # type: ignore
+        FlashCausalAttention, TorchCausalAttention)
+
     reproducibility.seed_all(7)
 
     cfg = om.create({
@@ -73,14 +73,17 @@ def gen_tca_mask():
     assert allclose_helper(x0.grad, x1.grad)
 
 
-# @pytest.mark.gpu
+@pytest.mark.gpu
 @pytest.mark.parametrize('attn_clip_qkv,attn_qk_ln', [
     (False, False),
     (False, True),
     (True, False),
     (True, True),
 ])
 def test_flash_triton(attn_clip_qkv, attn_qk_ln, device='cuda'):
+    from examples.llm.src.models.layers.attention import ( # type: ignore
+        FlashCausalAttention, TritonFlashCausalAttention)
+
     reproducibility.seed_all(7)
 
     cfg = om.create({
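Note on the change above: the GPU-only attention imports now happen inside the tests rather than at module scope, so pytest can still collect this file on machines without the flash-attention/Triton CUDA extensions installed, while the @pytest.mark.gpu marker (assumed to be registered in the repository's pytest configuration) routes the tests to GPU runners. A minimal, self-contained sketch of the same pattern, not the repository's test; the module my_pkg.gpu_kernels and the function fused_op are hypothetical:

import pytest
import torch


@pytest.mark.gpu  # custom marker; assumed registered in pytest's `markers` config
def test_fused_kernel_matches_reference():
    # Defer the GPU-only import so collection works on CPU-only machines.
    gpu_kernels = pytest.importorskip('my_pkg.gpu_kernels')  # hypothetical module

    if not torch.cuda.is_available():
        pytest.skip('requires a CUDA device')

    x = torch.randn(4, 8, device='cuda')
    out = gpu_kernels.fused_op(x)  # hypothetical GPU kernel under test
    ref = x.relu()                 # stand-in reference computation
    # Loose tolerances, as in allclose_helper above, absorb fp16/bf16 noise.
    assert torch.allclose(out, ref, rtol=1e-2, atol=1e-2)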
2 changes: 1 addition & 1 deletion examples/llm/tests/test_model.py
@@ -186,7 +186,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm, batch_size=2):
 
     device = 'cpu'
     neo_cfg.device = device
-    neo_cfg.max_seq_len = 1024
+    neo_cfg.max_seq_len = 256
 
     if prefixlm:
         neo_cfg.model.name = 'hf_prefix_lm'
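The second change shrinks the sequence length used by the CPU forward/backward test from 1024 to 256 tokens, reducing activation sizes and runtime. A small sketch of the OmegaConf override pattern the test relies on; the concrete keys and values here are illustrative, not the repository's actual config, which the real test loads from a YAML file:

from omegaconf import OmegaConf as om

# Illustrative config; the real neo_cfg comes from the repo's YAML.
neo_cfg = om.create({
    'device': 'cuda',
    'max_seq_len': 1024,
    'model': {'name': 'hf_gpt2'},
})

# Dotted attribute access mutates the config in place, as in the diff above.
neo_cfg.device = 'cpu'       # run the test on CPU
neo_cfg.max_seq_len = 256    # shorter sequences make the test smaller/faster
neo_cfg.model.name = 'hf_prefix_lm'

print(om.to_yaml(neo_cfg))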
