Skip to content

Commit 1479d3d

Browse files
authored
Enabled flash attn varlen for chunked prefill (#3148)
1 parent 5e7bb51 commit 1479d3d

File tree

9 files changed

+822
-99
lines changed

9 files changed

+822
-99
lines changed

csrc/cpu/aten/PagedAttention.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace cpu {
77

88
IPEX_DEFINE_DISPATCH(single_query_cached_kv_attention_kernel_stub);
99
IPEX_DEFINE_DISPATCH(reshape_and_cache_kernel_stub);
10+
IPEX_DEFINE_DISPATCH(flash_attn_var_len_kernel_stub);
1011

1112
/*
1213
*Calculate the masked multihead attention for decoder layer in decoder only
@@ -48,6 +49,35 @@ void reshape_and_cache_cpu(
4849
kCPU, key, value, key_cache, value_cache, slot_mapping);
4950
}
5051

52+
/*
 * CPU entry point for variable-length flash attention (chunked prefill over a
 * paged KV cache). Forwards every argument unchanged to the CPU kernel
 * registered on flash_attn_var_len_kernel_stub; the result is written into
 * `out` in place.
 *
 * @param out            output tensor, written in place by the kernel
 * @param query          query tensor
 * @param key            key tensor (paged cache layout — see block_table)
 * @param value          value tensor (paged cache layout — see block_table)
 * @param cu_seqlens_q   per-sequence offsets for queries
 *                       (presumably cumulative lengths — TODO confirm)
 * @param cu_seqlens_kv  per-sequence offsets for keys/values
 *                       (presumably cumulative lengths — TODO confirm)
 * @param max_seqlen_q   maximum query sequence length in the batch
 * @param max_seqlen_kv  maximum key/value sequence length in the batch
 * @param softmax_scale  scale applied to attention scores before softmax
 * @param is_causal      apply causal masking when true
 * @param block_table    maps sequences to blocks of the paged KV cache
 * @param alibi_slopes   optional per-head ALiBi slopes; pass nullopt to skip
 */
void flash_attn_varlen_cpu(
    at::Tensor& out,
    at::Tensor& query,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& cu_seqlens_q,
    at::Tensor& cu_seqlens_kv,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    const double softmax_scale,
    bool is_causal,
    at::Tensor& block_table,
    const c10::optional<at::Tensor>& alibi_slopes) {
  // Dispatch to the ISA-specific kernel implementation. The stub returns
  // void, so no `return` is needed here.
  flash_attn_var_len_kernel_stub(
      kCPU,
      out,
      query,
      key,
      value,
      cu_seqlens_q,
      cu_seqlens_kv,
      max_seqlen_q,
      max_seqlen_kv,
      softmax_scale,
      is_causal,
      block_table,
      alibi_slopes);
}
80+
5181
} // namespace cpu
5282
} // namespace torch_ipex
5383

@@ -68,5 +98,14 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
6898
"reshape_and_cache",
6999
c10::DispatchKey::CPU,
70100
torch_ipex::cpu::reshape_and_cache_cpu);
101+
m.def(
102+
"flash_attn_varlen_func(Tensor (a!)out, Tensor (a!)query, Tensor (a!)key, Tensor (a!)value, Tensor(a!) cu_seqlens_q,\
103+
Tensor(a!) cu_seqlens_kv, int max_seqlen_q, int max_seqlen_kv, float softmax_scale, bool is_causal, Tensor(a!) block_table, \
104+
Tensor? alibi_slopes)-> ()");
105+
106+
m.impl(
107+
"flash_attn_varlen_func",
108+
c10::DispatchKey::CPU,
109+
torch_ipex::cpu::flash_attn_varlen_cpu);
71110
}
72111
} // namespace

csrc/cpu/aten/PagedAttention.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ void reshape_and_cache(
2929
at::Tensor& value_cache,
3030
at::Tensor& slot_mapping);
3131

32+
// Variable-length flash attention for chunked prefill on CPU; writes the
// result into `out` in place. `cu_seqlens_q`/`cu_seqlens_kv` carry
// per-sequence offsets (presumably cumulative lengths — TODO confirm), and
// `block_table` maps sequences to blocks of the paged KV cache.
// NOTE(review): the definition in PagedAttention.cpp is named
// flash_attn_varlen_cpu — confirm this declaration actually has a matching
// definition, otherwise it is dead and should agree with the .cpp name.
void flash_attn_varlen(
    at::Tensor& out,
    at::Tensor& query,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& cu_seqlens_q,
    at::Tensor& cu_seqlens_kv,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    const double softmax_scale,
    bool is_causal,
    at::Tensor& block_table,
    const c10::optional<at::Tensor>& alibi_slopes);
45+
3246
using single_query_cached_kv_attention_fn = void (*)(
3347
at::Tensor& out, // [num_seqs, num_heads, head_size]
3448
at::Tensor& query, // [num_seqs, num_heads, head_size]
@@ -49,10 +63,25 @@ using reshape_and_cache_fn = void (*)(
4963
at::Tensor& value_cache,
5064
at::Tensor& slot_mapping);
5165

66+
// Signature of a variable-length flash-attention CPU kernel. ISA-specific
// implementations with this signature are registered on
// flash_attn_var_len_kernel_stub. Parameters mirror flash_attn_varlen:
// `out` is written in place, `cu_seqlens_q`/`cu_seqlens_kv` carry
// per-sequence offsets, and `block_table` maps sequences to paged KV-cache
// blocks.
using flash_attn_var_len_fn = void (*)(
    at::Tensor& out,
    at::Tensor& query,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& cu_seqlens_q,
    at::Tensor& cu_seqlens_kv,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    const double softmax_scale,
    bool is_causal,
    at::Tensor& block_table,
    const c10::optional<at::Tensor>& alibi_slopes);
79+
5280
IPEX_DECLARE_DISPATCH(
5381
single_query_cached_kv_attention_fn,
5482
single_query_cached_kv_attention_kernel_stub);
5583
IPEX_DECLARE_DISPATCH(reshape_and_cache_fn, reshape_and_cache_kernel_stub);
84+
IPEX_DECLARE_DISPATCH(flash_attn_var_len_fn, flash_attn_var_len_kernel_stub);
5685

5786
} // namespace cpu
5887
} // namespace torch_ipex

csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp

Lines changed: 1 addition & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -8,102 +8,14 @@
88

99
#include <ATen/Tensor.h>
1010
#include <aten/FlashAttention.h>
11+
#include <aten/utils/mkl_gemm.h>
1112
#include <torch/all.h>
1213
#include <torch/csrc/autograd/function.h>
1314
#include <limits>
1415
#include "../cpu/utils/isa_utils.h"
1516
#include "csrc/cpu/tpp/woq/tla.h"
16-
#include "mkl.h"
1717
#include "vec/vec.h"
1818

19-
inline void _mkl_gemm(
20-
const CBLAS_LAYOUT layout,
21-
const CBLAS_TRANSPOSE transa,
22-
const CBLAS_TRANSPOSE transb,
23-
const int& m,
24-
const int& n,
25-
const int& k,
26-
const float& alpha,
27-
const float* a,
28-
const int& lda,
29-
const float* b,
30-
const int& ldb,
31-
const float& beta,
32-
float* c,
33-
const int& ldc) {
34-
cblas_sgemm(
35-
layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
36-
}
37-
38-
inline void _mkl_gemm(
39-
const CBLAS_LAYOUT layout,
40-
const CBLAS_TRANSPOSE transa,
41-
const CBLAS_TRANSPOSE transb,
42-
const int& m,
43-
const int& n,
44-
const int& k,
45-
const double& alpha,
46-
const double* a,
47-
const int& lda,
48-
const double* b,
49-
const int& ldb,
50-
const double& beta,
51-
double* c,
52-
const int& ldc) {
53-
cblas_dgemm(
54-
layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
55-
}
56-
57-
inline void _mkl_gemm(
58-
const CBLAS_LAYOUT layout,
59-
const CBLAS_TRANSPOSE transa,
60-
const CBLAS_TRANSPOSE transb,
61-
const int& m,
62-
const int& n,
63-
const int& k,
64-
const float& alpha,
65-
const at::BFloat16* a,
66-
const int& lda,
67-
const at::BFloat16* b,
68-
const int& ldb,
69-
const float& beta,
70-
float* c,
71-
const int& ldc) {
72-
cblas_gemm_bf16bf16f32(
73-
layout,
74-
transa,
75-
transb,
76-
m,
77-
n,
78-
k,
79-
alpha,
80-
(const MKL_BF16*)(a),
81-
lda,
82-
(const MKL_BF16*)(b),
83-
ldb,
84-
beta,
85-
c,
86-
ldc);
87-
}
88-
89-
inline void _mkl_gemm(
90-
const CBLAS_LAYOUT layout,
91-
const CBLAS_TRANSPOSE transa,
92-
const CBLAS_TRANSPOSE transb,
93-
const int& m,
94-
const int& n,
95-
const int& k,
96-
const float& alpha,
97-
const at::Half* a,
98-
const int& lda,
99-
const at::Half* b,
100-
const int& ldb,
101-
const float& beta,
102-
float* c,
103-
const int& ldc) {
104-
TORCH_CHECK(false, "_mkl_gemm does not support FP16 yet");
105-
}
106-
10719
namespace torch_ipex {
10820
using namespace tpp;
10921
namespace cpu {

0 commit comments

Comments
 (0)