add softmax_d for mha_bwd #1161

Open · wants to merge 8 commits into base: main
83 changes: 3 additions & 80 deletions benchmarks/benchmark_causal.py
@@ -143,83 +143,6 @@ def attention_megatron(qkv):
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
# pytorch_profiler(torch.fft.rfft, u.float())

flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
exit(0)


def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

time_f = {}
time_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=qkv.device)
            qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
                causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b

            qkv = qkv.detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b

            # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # # Try both values of sequence_parallel and pick the faster one
            # f, b = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     False, repeats=repeats, verbose=False
            # )
            # _, b0 = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     True, repeats=repeats, verbose=False
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)

            if seqlen <= 8 * 1024:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                )
            else:
                f, b = float('nan'), float('nan')
            time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b

            # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # import xformers.ops as xops
            # f, b = time_fwd_bwd(
            #     xops.memory_efficient_attention, q, k, v,
            #     attn_bias=xops.LowerTriangularMask() if causal else None,
            #     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b


import pickle
with open('flash2_attn_time_h100.plk', 'wb') as fp:
    pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
# flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
# ideal_a100_time = flops / 312 / 1e9
# print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
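Note: the roofline estimate that this file now keeps only as a comment assumes an A100 FP16/BF16 tensor-core peak of 312 TFLOPs/s and the usual 4·b·s²·h·d matmul FLOP count for one attention forward pass; the 2.5× backward factor is the same approximation used in the original script. A minimal sketch of the arithmetic, with shapes chosen to match one of the benchmark configurations:

# Roofline estimate for one attention forward pass, assuming 312 TFLOPs/s FP16 peak on A100.
batch_size, seqlen, nheads, headdim = 2, 8192, 16, 128   # dim = nheads * headdim = 2048

flops = 4 * batch_size * seqlen ** 2 * nheads * headdim  # QK^T and PV matmuls, non-causal
ideal_a100_time_ms = flops / 312e12 * 1e3                # equivalent to flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time_ms:.3f}ms, bwd time: {ideal_a100_time_ms * 2.5:.3f}ms")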
19 changes: 11 additions & 8 deletions benchmarks/benchmark_flash_attention.py
@@ -5,18 +5,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

try:
    from triton.ops.flash_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
import xformers.ops as xops
@@ -71,13 +73,13 @@ def time_fwd_bwd(func, *args, **kwargs):
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
bs_seqlen_vals = [(1, 65536), (1, 131072)]
causal_vals = [False, True]
headdim_vals = [64, 128]
headdim_vals = [64]
dim = 2048
dropout_p = 0.0

methods = (["Flash2", "Pytorch"]
methods = (["Flash2", ]
+ (["Triton"] if attention_triton is not None else [])
+ (["xformers.c"] if xops is not None else [])
+ (["xformers.f"] if xops is not None else []))
@@ -95,8 +97,9 @@ def time_fwd_bwd(func, *args, **kwargs):
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            q, k, v = qkv.flatten(1, 2).chunk(3, dim=1)
            f, b = time_fwd_bwd(
                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=True
            )
            time_f[config, "Flash2"] = f
            time_b[config, "Flash2"] = b
@@ -170,9 +173,9 @@ def time_fwd_bwd(func, *args, **kwargs):
                    time_f_b[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method]} s/fwd, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, {time_b[config, method]} s/bwd, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s, {time_f_b[config, method]} s/fwd_bwd"
                )


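Note on the q/k/v split introduced above: `qkv.flatten(1, 2).chunk(3, dim=1)` yields three tensors of shape (batch_size, seqlen, nheads, headdim), which is all a timing run needs, but it does not recover the packed q/k/v values themselves (a value-preserving split would be `qkv.unbind(dim=2)`). A small shape-only sketch, with deliberately small dimensions:

import torch

batch_size, seqlen, nheads, headdim = 2, 1024, 32, 64  # dim = 2048, as in the benchmark
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, dtype=torch.float16)

# Split used in the modified benchmark: correct shapes, scrambled values (fine for timing).
q, k, v = qkv.flatten(1, 2).chunk(3, dim=1)
assert q.shape == (batch_size, seqlen, nheads, headdim)

# Value-preserving split of a packed qkv tensor, for comparison.
q_ref, k_ref, v_ref = qkv.unbind(dim=2)
assert q_ref.shape == (batch_size, seqlen, nheads, headdim)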
23 changes: 19 additions & 4 deletions csrc/flash_attn/flash_api.cpp
@@ -752,6 +752,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
        const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse, // b x h x seqlen_q
        c10::optional<at::Tensor> &softmax_d_, // b x h x seqlen_q_rounded
        c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
@@ -789,7 +790,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
@@ -798,7 +798,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

const auto sizes = q.sizes();
@@ -818,6 +817,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (!softmax_d_.has_value()) {
        TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
        TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    }

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -831,7 +836,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

at::Tensor dq, dk, dv;
@@ -879,7 +883,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    bool has_softmax_d = softmax_d_.has_value();
    at::Tensor softmax_d;
    if (!has_softmax_d) {
        softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    } else {
        softmax_d = softmax_d_.value();
        TORCH_CHECK(softmax_d.dtype() == torch::kFloat32, "softmax_d must have dtype float32");
        CHECK_DEVICE(softmax_d);
        TORCH_CHECK(softmax_d.stride(-1) == 1, "softmax_d must have contiguous last dimension");
        CHECK_SHAPE(softmax_d, batch_size, num_heads, seqlen_q_rounded);
    }
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
Expand Down Expand Up @@ -926,6 +940,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
window_size_left,
window_size_right,
deterministic);
params.has_softmax_d = has_softmax_d;
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

auto launch = &run_mha_bwd;
Expand Down
1 change: 1 addition & 0 deletions csrc/flash_attn/src/flash.h
@@ -127,6 +127,7 @@ struct Flash_fwd_params : public Qkv_params {

bool is_bf16;
bool is_causal;
bool has_softmax_d;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
43 changes: 23 additions & 20 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -137,9 +137,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
// Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
// Shape<Int<kBlockM>, Int<kHeadDim>>{},
// make_stride(params.o_row_stride, _1{}));

Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
@@ -197,7 +198,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
// Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
@@ -379,7 +380,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}

Tensor tdOrdO = make_fragment_like(tdOgdO);
Tensor tdOrO = make_fragment_like(tdOgO);
// Tensor tdOrO = make_fragment_like(tdOgO);
if (!Is_first) {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
@@ -389,9 +390,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
// flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
// gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
// );
}
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
@@ -429,11 +430,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
flash::cp_async_fence();

// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
if (Is_first) {
cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
}

// if (Is_first) {

// cute::copy(tdOrdO, tdOsdO);
// dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
// Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
// }

if (Kernel_traits::Is_V_in_regs) {
cute::cp_async_wait<1>();
@@ -628,9 +631,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Advance gdO
tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
if (Is_first) {
tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
// tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
// flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
} else {
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
flash::cp_async_fence();
@@ -685,11 +688,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
flash::cp_async_fence();
}

if (Is_first && m_block > m_block_min) {
cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
}
// if (Is_first && m_block > m_block_min) {
// cute::copy(tdOrdO, tdOsdO);
// dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
// Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
// }

if (Is_last) {
__syncthreads();
21 changes: 17 additions & 4 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -55,6 +55,11 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_clear_dqaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dQaccum<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
@@ -76,12 +81,20 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
}
dim3 grid_n(gridDimx, params.b, params.h);

    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    if (! params.has_softmax_d){
        if (! params.deterministic) {
            flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else {
            flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
    } else {
        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
        if (! params.deterministic) {
            // do atomicAdds on.
            flash_bwd_clear_dqaccum_kernel<Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);

        }
    }

C10_CUDA_KERNEL_LAUNCH_CHECK();

// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
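Taken together with the launch changes above (the dot_do_o preprocess is skipped when has_softmax_d is set, and the non-deterministic path clears dq_accum with the new flash_bwd_clear_dqaccum_kernel before its atomicAdd accumulation), a caller supplying softmax_d is expected to hand over a float32, last-dim-contiguous buffer of shape (batch_size, num_heads, seqlen_q_rounded) with seqlen_q_rounded = round_multiple(seqlen_q, 128). A hedged caller-side sketch; mha_bwd_with_softmax_d below is a hypothetical wrapper name, not an API added by this PR:

import torch

def round_multiple(x, m):
    # Mirrors the round_multiple lambda in mha_bwd.
    return (x + m - 1) // m * m

batch_size, seqlen_q, num_heads = 4, 1000, 16
seqlen_q_rounded = round_multiple(seqlen_q, 128)  # 1024

# Layout the new checks in mha_bwd expect: float32, contiguous last dimension,
# shape (batch_size, num_heads, seqlen_q_rounded), on the same CUDA device as q/k/v.
softmax_d = torch.zeros(batch_size, num_heads, seqlen_q_rounded,
                        dtype=torch.float32, device="cuda")

# Hypothetical call for illustration only -- the actual entry point is the C++ mha_bwd above:
# dq, dk, dv, softmax_d = mha_bwd_with_softmax_d(dout, q, k, v, out, softmax_lse, softmax_d, ...)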