Fix causal flag on softmax Triton on CPU (facebookresearch#212)
fmassa authored Feb 8, 2022
1 parent 143318f commit 31c2292
Showing 2 changed files with 13 additions and 15 deletions.
tests/test_triton_softmax.py (16 changes: 10 additions & 6 deletions)

@@ -107,14 +107,15 @@ def test_softmax_fp16(dtype):
 @pytest.mark.parametrize("masking", [True, False])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("contiguous", [True, False])
-def test_softmax_parity_fallback(log, masking, causal, contiguous):
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_softmax_parity_fallback(log, masking, causal, contiguous, device):
     """Check that the fallback paths are correct"""
     torch.random.manual_seed(0)
 
     shape = (16, 16)
 
     # Check the result of a FW pass
-    X = torch.normal(0, 1, size=shape, device="cpu", requires_grad=False)
+    X = torch.normal(0, 1, size=shape, device=device, requires_grad=False)
 
     if not contiguous:
         # Make sure that the buffer is not contiguous
@@ -125,15 +126,18 @@ def test_softmax_parity_fallback(log, masking, causal, contiguous):
     X_.requires_grad = True
 
     seq = shape[1]
-    mask = torch.zeros((seq, seq))
+    mask = torch.zeros((seq, seq), device=device)
     if masking:
-        mask[torch.rand((seq, seq)) > 0.8] = -float("inf")
+        mask[torch.rand((seq, seq), device=device) > 0.8] = -float("inf")
 
+    mask_causal = torch.zeros_like(mask)
     if causal:
-        mask[~torch.tril(torch.ones_like(mask)).bool()] = -float("inf")
+        mask_causal[~torch.tril(torch.ones_like(mask)).bool()] = -float("inf")
 
     y_torch = (
-        torch.log_softmax(X + mask, dim=-1) if log else torch.softmax(X + mask, dim=-1)
+        torch.log_softmax(X + mask + mask_causal, dim=-1)
+        if log
+        else torch.softmax(X + mask + mask_causal, dim=-1)
     )
     y_triton = (
         triton_log_softmax(X_, mask, causal)
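
For context, the reference side of the test now adds a separate `mask_causal` on top of the user mask instead of folding the causal constraint into `mask`, because `mask` is still handed to the Triton path together with the `causal` flag. The snippet below is a hypothetical standalone sketch, not part of the test file: it checks that the tril-based causal mask built here is equivalent to the triu-based construction that the fallback in xformers/triton/softmax.py applies.

```python
import torch

seq = 16
torch.random.manual_seed(0)
X = torch.normal(0, 1, size=(seq, seq))

# Optional user mask with random -inf entries, as in the test
mask = torch.zeros((seq, seq))
mask[torch.rand((seq, seq)) > 0.8] = -float("inf")

# Causal mask, test-style: -inf strictly above the diagonal via ~tril
mask_causal = torch.zeros_like(mask)
mask_causal[~torch.tril(torch.ones_like(mask)).bool()] = -float("inf")

# Causal mask, fallback-style: triu of a full -inf tensor, diagonal excluded
mask_causal_triu = torch.triu(torch.full_like(mask, float("-inf")), diagonal=1)

ref = torch.softmax(X + mask + mask_causal, dim=-1)
alt = torch.softmax(X + mask + mask_causal_triu, dim=-1)
# Rows that end up fully masked give NaN on both sides, hence equal_nan
assert torch.allclose(ref, alt, equal_nan=True)
```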
xformers/triton/softmax.py (12 changes: 3 additions & 9 deletions)

@@ -19,7 +19,6 @@
 
 
 _triton_registered_overflow = False
-_triton_registered_warnings = False
 _triton_softmax_fp16_enabled = False  # NOTE: PyTorch keeps softmax as fp32
 
 
@@ -179,7 +178,6 @@ def _softmax_dispatch(
     # - there was no previous failure
 
     global _triton_registered_overflow
-    global _triton_registered_warnings
 
     try:
         if torch.cuda.is_available() and x.is_cuda and not _triton_registered_overflow:
@@ -194,16 +192,12 @@
         )
         logging.warning(e)
 
-    if causal and not _triton_registered_warnings:
-        logging.warning(
-            "Triton softmax could not be used. \
-                The causal flags is being passed but it does not provide any benefit with PyTorch softmax."
-        )
-        _triton_registered_warnings = True
-
     if mask is not None:
         x = x + mask
 
+    if causal:
+        x = x + torch.triu(torch.full_like(x, float("-inf")), diagonal=1)
+
     if log:
         return torch.log_softmax(x, dim=-1)
     else:
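
The behavioural change is easiest to read outside the dispatch logic: when the Triton kernel cannot run (notably on CPU), the fallback no longer just warns that `causal` is ignored; it materializes the causal constraint as an additive upper-triangular -inf mask before calling the PyTorch softmax. Below is a minimal sketch of that fallback path with a hypothetical helper name; the real logic sits inline in `_softmax_dispatch`.

```python
import torch

def _fallback_softmax(x, mask=None, log=False, causal=False):
    # Hypothetical condensed view of the PyTorch fallback path after this commit.
    if mask is not None:
        x = x + mask
    if causal:
        # New in this commit: honour `causal` by masking everything strictly
        # above the diagonal with -inf before the softmax.
        x = x + torch.triu(torch.full_like(x, float("-inf")), diagonal=1)
    return torch.log_softmax(x, dim=-1) if log else torch.softmax(x, dim=-1)

# On CPU this now matches an explicit causal mask applied by hand.
x = torch.randn(16, 16)
causal_mask = torch.zeros(16, 16)
causal_mask[~torch.tril(torch.ones(16, 16)).bool()] = -float("inf")
assert torch.allclose(_fallback_softmax(x, causal=True), torch.softmax(x + causal_mask, dim=-1))
```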
