fp8 bwd #108

Draft: wants to merge 29 commits into main_perf
Changes from 1 commit
Commits (29)
94b2da3
Alex fp8 work
alexkranias-amd Dec 4, 2024
7434112
fix mismatches
micmelesse Dec 9, 2024
2ea54b1
no navi for now
micmelesse Dec 9, 2024
89d3d7d
fix: ref uses scaling + added ENV VAR to enable/disable quantization …
alexkranias-amd Dec 9, 2024
c65af82
fix: fp8 ref matches kernel
alexkranias-amd Dec 9, 2024
9297d78
misc: added note about p_scale
alexkranias-amd Dec 9, 2024
f92ca5b
feat: added precision error test for various triton ops
alexkranias-amd Dec 9, 2024
9ed1d00
save
alexkranias-amd Dec 4, 2024
1c3f756
feat: added benchmark for fp8 flash attention
alexkranias-amd Dec 10, 2024
c4ca789
fix: quantization scaling in fp8 benchmark
alexkranias-amd Dec 10, 2024
937e814
checkpoint
alexkranias-amd Dec 4, 2024
fd342f7
feat: added fp8 to precision test
alexkranias-amd Dec 11, 2024
543736b
fix: refactor fp32 for torch, moved scaling of fp8 to out of kernel
alexkranias-amd Dec 13, 2024
210e2df
Fix test_op_fwd_prefill
brunomazzottiamd Dec 16, 2024
d8dd966
Document two tests that are failing with FP8
brunomazzottiamd Dec 16, 2024
1835390
Remove cast to fp16, output is already being cast to fp32
brunomazzottiamd Dec 17, 2024
a2624a9
Increase error tolerance for fp8
brunomazzottiamd Dec 17, 2024
8eab5e5
Enable more test cases
brunomazzottiamd Dec 17, 2024
0cf49ce
Fix bug for "bshd" layout
brunomazzottiamd Dec 17, 2024
f413f33
Take max fp8 value into account while computing scales
brunomazzottiamd Dec 18, 2024
4f3e633
Compute 1st FA GEMM without casting to fp16
brunomazzottiamd Dec 18, 2024
6773e3a
Remove redundant `v.to(v.type.element_ty)` cast
brunomazzottiamd Dec 18, 2024
b31cd5d
Fix global scaling for "bhsd" and "bshd" layouts
brunomazzottiamd Dec 18, 2024
3044d7b
[WIP] First attempt to support "thd" layout
brunomazzottiamd Dec 18, 2024
5856c6b
Refactor fp8 scale computation
brunomazzottiamd Dec 23, 2024
a170a08
Compute p scale factor and pass it to the kernel
brunomazzottiamd Dec 23, 2024
85c62ae
Fix minor coding mistakes
brunomazzottiamd Dec 23, 2024
13b07df
Use p scale factor in the kernel
brunomazzottiamd Dec 23, 2024
02a4d8f
Improve scale factor generation
brunomazzottiamd Dec 23, 2024
[WIP] First attempt to support "thd" layout
brunomazzottiamd committed Dec 18, 2024
commit 3044d7bc22d2f45fdfe8b54d3aee6d32ebfe6fdc
22 changes: 12 additions & 10 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -539,9 +539,13 @@ def attn_fwd(Q, K, V, bias, Q_SCALE, K_SCALE, V_SCALE, stride_qscale_z, stride_k


def _scale_fp8(x, x_scale, layout):
assert layout in ["bhsd", "bshd"]
n = x.to(torch.float32)
d = x_scale[:, :, None, None] if layout == "bhsd" else x_scale[:, None, :, None]
if layout == "bhsd":
d = x_scale[:, :, None, None]
elif layout == "bshd":
d = x_scale[:, None, :, None]
elif layout == "thd":
pass # TODO: Implement fp8 scaling for thd layout.
return (n / d).to(x.dtype)


@@ -567,12 +571,15 @@ def attention_prefill_forward_triton_impl(
# misc
return_softmax,
use_exp2):

is_varlen = layout == "thd" # check if varlen
scale_per_head = not is_varlen # use False to test global scaling for "bhsd" and "bshd" layouts
is_fp8 = check_is_fp8(q)

if is_fp8:
# if qkv are fp8, then find scaling factor for quantization
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
# TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=scale_per_head, layout=layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k)
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
q = _scale_fp8(q, q_scale, layout)
@@ -607,11 +614,6 @@ def attention_prefill_forward_triton_impl(
print("return_scores:", return_softmax)
print("use_exp2:", use_exp2)

# import pdb; pdb.set_trace()

# check if varlen
is_varlen = layout == "thd"

# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if (bias is not None):
assert (bias.numel() < 2**31)
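The `thd` branch of `_scale_fp8` is still a TODO in this commit. One way it could be filled in, sketched here purely as an assumption (the `cu_seqlens` argument and the per-sequence loop are not part of the diff), is to divide each sequence's token slice by its own per-head scale:

```python
import torch

def _scale_fp8_thd(x, x_scale, cu_seqlens):
    # Sketch only. Assumes x is laid out as (total_tokens, heads, head_dim)
    # and x_scale as (batch, heads): one scale per (sequence, head) pair.
    n = x.to(torch.float32)
    out = torch.empty_like(n)
    for b in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        # Broadcast this sequence's per-head scales over its tokens.
        out[start:end] = n[start:end] / x_scale[b][None, :, None]
    return out.to(x.dtype)
```

A vectorized variant could expand the per-sequence scales with `torch.repeat_interleave` instead of looping over sequences.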
7 changes: 5 additions & 2 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -380,11 +380,14 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_ali
)
@pytest.mark.parametrize('causal', [False]) # FIXME: There are some mismatches for causal.
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('layout', ["bhsd", "bshd"])
@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"])
@pytest.mark.parametrize('use_exp2', [True, False])
@pytest.mark.parametrize('dtype', [torch.float16, torch.float8_e4m3fnuz])
@pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues.
def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, use_exp2, dtype, DEBUG_INPUT):
if layout == "thd" and dtype == torch.float8_e4m3fnuz:
pytest.skip("fp8 support for thd layout is under development.")

if dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
atol = 1.009e-01
rtol = 9.128e-02
@@ -459,7 +462,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou

if DEBUG:
print()
print("Compare Triton Impl with refernce Pytorch Impl")
print("Compare Triton Impl with reference Pytorch Impl")

# this can be set to true manually or when using dropout
if metadata.return_scores:
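For reference, the relaxed tolerances above are meant to absorb fp8 rounding error. A self-contained illustration of roughly how such a comparison behaves (the actual assert in the test body is outside this hunk, and the availability of fp8 casts depends on the PyTorch build):

```python
import torch

# Hypothetical illustration: round-trip a float32 tensor through fp8 and check
# it against the original with the tolerances used for fp8 dtypes above.
atol, rtol = 1.009e-01, 9.128e-02
ref = torch.randn(4, 128)
approx = ref.to(torch.float8_e4m3fnuz).to(torch.float32)
torch.testing.assert_close(approx, ref, atol=atol, rtol=rtol)
```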
30 changes: 20 additions & 10 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -204,20 +204,22 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda

if DEBUG_INPUT:
# Initialize q, k, v with deterministic values
q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1)
q = torch.arange(total_q, dtype=torch.float32, device=device).view(total_q, 1, 1)
q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_()
k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
k = torch.arange(total_k, dtype=torch.float32, device=device).view(total_k, 1, 1)
k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
v = torch.arange(total_k, dtype=torch.float32, device=device).view(total_k, 1, 1)
v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
sm_scale = 1
else:
# Initialize q, k, v with random values
q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
q = torch.randn((total_q, HQ, D_HEAD), dtype=torch.float32, device=device).requires_grad_()
k = torch.randn((total_k, HK, D_HEAD), dtype=torch.float32, device=device).requires_grad_()
v = torch.randn((total_k, HK, D_HEAD), dtype=torch.float32, device=device).requires_grad_()
sm_scale = D_HEAD ** -0.5

q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)

input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
return q, k, v, input_metadata
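The switch above to generating q/k/v in float32 and casting afterwards matters because torch.randn and torch.arange cannot produce fp8 tensors directly (an assumption about current PyTorch behavior). A minimal sketch of the pattern:

```python
import torch

# Generate in float32, then quantize; creating the tensor directly in an fp8
# dtype with torch.randn would raise an error on current PyTorch builds.
dtype = torch.float8_e4m3fnuz
q = torch.randn((128, 4, 64), dtype=torch.float32)
q_fp8 = q.to(dtype)
```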
@@ -341,6 +343,7 @@ def is_cdna():
def is_rdna():
return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")


def check_is_fp8(x: torch.Tensor):
if REMOVE_QUANTIZATION_SCALING:
return False # makes all methods believe they aren't working with fp8s, so no scaling is applied
@@ -353,7 +356,8 @@ def check_is_fp8(x: torch.Tensor):
}
return x.dtype in fp8_types

def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd'):

def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd', cu_seqlens_q=None, cu_seqlens_k=None):
"""
Create scale tensors for q and k based on the scaling configuration.

@@ -386,9 +390,10 @@ def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd'):
q_float32 = q.to(torch.float32)
k_float32 = k.to(torch.float32)
v_float32 = v.to(torch.float32)

if SCALE_PER_HEAD:
if is_varlen:
# FIXME: varlen should be supported.
assert False, "VARLEN NOT SUPPORTED FOR SCALE PER HEAD"
else:
# Compute max for each batch-head pair.
@@ -410,8 +415,12 @@ def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd'):
batch_q, head_q, _, _ = q.shape
batch_k, head_k, _, _ = k.shape
elif layout == "thd":
# FIXME: varlen not working! ValueError: not enough values to unpack (expected 4, got 3)
pass
assert cu_seqlens_q is not None
batch_q = len(cu_seqlens_q) - 1
head_q = q.shape[1]
assert cu_seqlens_k is not None
batch_k = len(cu_seqlens_k) - 1
head_k = k.shape[1]
assert batch_q == batch_k
q_scale = torch.full((batch_q, head_q), q_global_max, device=q.device)
k_scale = torch.full((batch_k, head_k), k_global_max, device=k.device)
@@ -437,6 +446,7 @@ def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd'):
elif layout == 'bhsd':
batch, head, _, _ = q.shape
else:
# FIXME: varlen should be supported.
assert False, "VARLEN NOT SUPPORTED"
q_scale = torch.ones((batch, head), device=q.device)
k_scale = torch.ones((batch, head), device=k.device)
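For the global-scaling path, the thd branch now derives batch and head counts from cu_seqlens, while the scale values come from the global maxima (q_global_max and friends) computed elsewhere in this function. A rough standalone sketch of that idea, with the helper name and the normalization by the fp8 maximum treated as assumptions (the latter follows the earlier commit "Take max fp8 value into account while computing scales"):

```python
import torch

def global_fp8_scale(x, fp8_dtype=torch.float8_e4m3fnuz):
    # Sketch: a single scale for the whole tensor, chosen so the largest
    # magnitude maps onto the largest representable fp8 value. The PR
    # broadcasts this into a (batch, heads) tensor; that step is omitted here.
    fp8_max = torch.finfo(fp8_dtype).max
    x_abs_max = x.to(torch.float32).abs().max().clamp(min=1e-12)
    return x_abs_max / fp8_max

# Example on a "thd"-shaped tensor (total_tokens, heads, head_dim):
x = torch.randn(512, 8, 64)
scale = global_fp8_scale(x)
x_fp8 = (x / scale).to(torch.float8_e4m3fnuz)  # quantize
x_back = x_fp8.to(torch.float32) * scale       # approximate reconstruction
```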