Commit ac3bd5c

abenmaopujiang2018 authored and committed
[layers] Add bf16-type input/output support for flash attention (#252)
1 parent 95de632 · commit ac3bd5c

6 files changed, +122 -174 lines

src/layers/attention.h

Lines changed: 46 additions & 33 deletions
@@ -306,17 +306,16 @@ class Attention {
         }
 
         // TODO: refine the logic (and support large inputSeqLen when pastSeqLen > 0)
-        if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
-            if (pastSeqLen == 0) {
+        if (pastSeqLen == 0) {
+            if (ctx->inputSeqLen >= getFlashThresh()) {
+                flashAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
+            } else if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
                 selfAttentionBF16(ctx, query, key, value, imBuffer, presentKey, presentValue);
             } else {
                 fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
             }
         } else {
-            if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0)
-                flashAttention(
-                        ctx, qkvGroupMatMul, outBuffer, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
-            else { fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen); }
+            fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
         }
         t4.release();
 
@@ -809,11 +808,15 @@ class Attention {
         } // end for b
     }
 
-    template <typename KVCacheT, typename AttnT = bfloat16_t>
-    void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpBuf,
-            hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
-            const float *attnMask, int pastSeqLen) {
-
+    template <typename KVCacheT>
+    void flashAttention(DecoderContext *ctx, hpj::Matrix<ImT> &query, hpj::Matrix<ImT> &key,
+            hpj::Matrix<ImT> &value, hpj::Matrix<ImT> &result, KVCacheTensor<KVCacheT> &presentKey,
+            KVCacheTensor<KVCacheT> &presentValue, const float *attnMask, int pastSeqLen) {
+#if defined(AVX512_BF16_WEIGHT_ONLY_BF16)
+        using AttnT = bfloat16_t;
+#else
+        using AttnT = float;
+#endif
         // How many heads this task should do
         int batchSize = ctx->batchSize;
         int respQHeads = this->endQHead - this->startQHead;
@@ -828,31 +831,41 @@ class Attention {
 
         // TODO: kv dtype conversion for prefixSharing
         AttnT *k, *v;
-        if constexpr (std::is_same_v<AttnT, bfloat16_t>) {
+        int kvStride;
+        if constexpr (!std::is_same_v<AttnT, ImT>) {
+            //Timer tmc(true, "convert KV matrix into bf16");
+            kvStride = kvCols * 2;
+            AttnT *kvBuf = (AttnT *)SimpleMemPool::instance().getBuffer(
+                    "flashKVBuf", batchSize * srcLen * kvStride * sizeof(AttnT));
 #pragma omp parallel for collapse(3)
             for (uint64_t b = 0; b < batchSize; ++b)
                 for (uint64_t seq = 0; seq < srcLen; ++seq)
-                    for (uint64_t i = qCols; i < qkvCols; i += headSize) {
-                        const float *srcPtr = qkvMatMul.Data() + b * srcLen * qkvCols + seq * qkvCols + i;
-                        bfloat16_t *dstPtr
-                                = (bfloat16_t *)tmpBuf.Data() + b * srcLen * kvCols * 2 + seq * kvCols * 2 + i - qCols;
-                        bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
+                    for (uint64_t i = 0; i < kvCols * 2; i += headSize) {
+                        const ImT *srcPtr = key.Data() + b * srcLen * qkvCols + seq * qkvCols + i;
+                        AttnT *dstPtr
+                                = kvBuf + b * srcLen * kvStride + seq * kvStride + i;
+                        if constexpr (std::is_same_v<AttnT, bfloat16_t> && std::is_same_v<ImT, float>) {
+                            bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
+                        } else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bfloat16_t>) {
+                            bfloat16_t::cvt_bfloat16_to_float(srcPtr, dstPtr, headSize);
+                        } else {
+                            printf("Not supported Type in Flash Attention yet\n");
+                            exit(-1);
+                        }
                     }
 
-            k = (AttnT *)tmpBuf.Data();
-            v = (AttnT *)tmpBuf.Data() + kvCols;
+            k = kvBuf;
+            v = kvBuf + kvCols;
         } else {
-            k = qkvMatMul.Data() + respQHeads * headSize;
-            v = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize;
+            kvStride = qkvCols;
+            k = key.Data();
+            v = value.Data();
         }
 
-        float *query = qkvMatMul.Data();
         // [batch, src, head, headsize]
-        scaledDpAttention<AttnT>(query, k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
-                headSize, result.Data(), qkvCols, kvCols * 2, ctx->hiddenSize);
+        scaledDpAttention<AttnT>(query.Data(), k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
+                headSize, result.Data(), qkvCols, kvStride, result.Stride());
 
-        float *key = qkvMatMul.Data() + respQHeads * headSize;
-        float *value = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize;
         // For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache seperately
         // When M dimension is split, also multiple tasks per copy, so do copy seperately
 #pragma omp parallel for collapse(3)
@@ -862,10 +875,10 @@
             // Re-layout is needed: (bs, seq=1, hidden_size) -> (seq=1, bs, hidden_size)
            // Be noted: for group attention, the key/value is less than query
            for (uint64_t seq = 0; seq < tgtLen; ++seq) {
-                auto srcK = key + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
+                auto srcK = key.Data() + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
                auto dstK = presentKey.getSequence(pastSeqLen + seq, b, i);
 
-                auto srcV = value + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
+                auto srcV = value.Data() + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
                auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i);
 
                xft::copy(dstK, srcK, headSize);
@@ -877,8 +890,8 @@
 
     // scaled dot-product attention: bmm1 + softmax + bmm2
     template <typename AttnT>
-    void scaledDpAttention(const float *query, const AttnT *key, const AttnT *value, const float *attnMask, float scale,
-            int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output,
+    void scaledDpAttention(const ImT *query, const AttnT *key, const AttnT *value, const float *attnMask, float scale,
+            int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, ImT *output,
            int qStride, int kvStride, int stride) {
         // output = trans(softmax(query * trans(key)) * value)
         int nth = omp_get_max_threads();
@@ -916,17 +929,17 @@
         }
 
 #pragma omp parallel for collapse(3)
-        for (int i = 0; i < batchSize; ++i) {
+        for (uint64_t i = 0; i < batchSize; ++i) {
             for (int j = 0; j < numQHead; ++j) {
                 for (int m = 0; m < srcLen; m += srcBlk) {
                     int tid = omp_get_thread_num();
 
                     int qRealBlk = std::min(srcBlk, srcLen - m);
                     uint64_t srcOff = i * srcLen * qStride + j * headSize;
                     uint64_t outOff = i * srcLen * stride + j * headSize;
-                    const float *qbuf = query + srcOff + m * qStride;
+                    const ImT *qbuf = query + srcOff + m * qStride;
                     AttnT *q = (AttnT *)qArr[tid];
-                    float *out = output + outOff + m * stride;
+                    ImT *out = output + outOff + m * stride;
 
                     // reset out
                     for (int ii = 0; ii < qRealBlk; ++ii) {
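
The heart of these hunks is that flashAttention now takes activations in the intermediate type ImT (float or bfloat16_t) and converts the KV block to a build-selected compute type AttnT only when the two differ. The snippet below is a minimal, self-contained sketch of that if-constexpr conversion dispatch; the bf16 struct, toBf16/toFloat, and convertKV are illustrative stand-ins written for this note, not the project's bfloat16_t API.

```cpp
// Illustrative sketch only: a stand-in bf16 type and the if-constexpr dispatch
// used to convert KV data from the intermediate type ImT to the attention
// compute type AttnT (the real code calls bfloat16_t::cvt_* helpers instead).
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <type_traits>

struct bf16 { uint16_t bits; };  // hypothetical stand-in for bfloat16_t

static inline bf16 toBf16(float f) {          // keep the top 16 bits (truncation)
    uint32_t u; std::memcpy(&u, &f, sizeof(u));
    return bf16{static_cast<uint16_t>(u >> 16)};
}
static inline float toFloat(bf16 b) {
    uint32_t u = static_cast<uint32_t>(b.bits) << 16;
    float f; std::memcpy(&f, &u, sizeof(f));
    return f;
}

// Convert n values from ImT to AttnT; unsupported combinations fail loudly,
// mirroring the new error path in flashAttention.
template <typename AttnT, typename ImT>
void convertKV(const ImT *src, AttnT *dst, int n) {
    if constexpr (std::is_same_v<AttnT, ImT>) {
        std::memcpy(dst, src, n * sizeof(AttnT));          // no conversion needed
    } else if constexpr (std::is_same_v<AttnT, bf16> && std::is_same_v<ImT, float>) {
        for (int i = 0; i < n; ++i) dst[i] = toBf16(src[i]);
    } else if constexpr (std::is_same_v<AttnT, float> && std::is_same_v<ImT, bf16>) {
        for (int i = 0; i < n; ++i) dst[i] = toFloat(src[i]);
    } else {
        std::printf("Not supported Type in Flash Attention yet\n");
        std::exit(-1);
    }
}

int main() {
    float src[4] = {1.0f, 0.5f, -2.25f, 3.140625f};
    bf16 dst[4];
    convertKV(src, dst, 4);                // float -> bf16, as in the bf16 build
    std::printf("%f\n", toFloat(dst[0]));  // prints 1.000000
    return 0;
}
```

In the real code the conversion is vectorized through bfloat16_t::cvt_float_to_bfloat16 / cvt_bfloat16_to_float and runs inside the OpenMP collapse(3) loop over batch, sequence, and head blocks shown in the diff.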

src/layers/mlp_chatglm2.h

Lines changed: 4 additions & 10 deletions
@@ -38,8 +38,7 @@ class ChatGLM2MLP : public LlamaMLP<WeiT> {
         auto range = SplitUtil::getTaskRange(intermediateSize, ctx->numSplit, ctx->splitIdx);
         int colSplit = range.second - range.first;
 
-        setMLPOPTConfig();
-        if (!enableCATMLP) {
+        if (!enableCATMLP()) {
             OriWeiT *gateW = (OriWeiT *)malloc(hiddenSize * colSplit * sizeof(OriWeiT));
             OriWeiT *upW = (OriWeiT *)malloc(hiddenSize * colSplit * sizeof(OriWeiT));
             if (trans) {
@@ -93,14 +92,9 @@ class ChatGLM2MLP : public LlamaMLP<WeiT> {
             }
         }
         // Horizontally split the down weight
-        if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            ctx->mmHelper->convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, nullptr, nullptr, false,
-                    this->downWeight, this->downWeightScale, this->downWeightZero, this->gateWeightSum);
-        } else {
-            ctx->mmHelper->convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, nullptr, nullptr, false,
-                    convertedDownWeight, this->downWeightScale, this->downWeightZero, this->downWeightSum);
-            ctx->mmHelper->packWeight(trans, convertedDownWeight, this->downWeight);
-        }
+        ctx->mmHelper->convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, nullptr, nullptr, false,
+                convertedDownWeight, this->downWeightScale, this->downWeightZero, this->downWeightSum);
+        ctx->mmHelper->packWeight(trans, convertedDownWeight, this->downWeight);
 #ifdef DEBUG
         this->dbg.debugPrint("convertedGateWeight [%d, %d](%d):\n", convertedGateWeight.Rows(),
                 convertedGateWeight.Cols(), convertedGateWeight.Stride());

src/layers/mlp_llama.cpp

Lines changed: 0 additions & 8 deletions
@@ -17,14 +17,6 @@
 
 #include <unordered_map>
 
-bool enableCATMLP;
-bool enableCBLASMLP;
-
-void setMLPOPTConfig() {
-    enableCATMLP = (getenv("ENABLE_CAT_MLP") ? atoi(getenv("ENABLE_CAT_MLP")) : 1);
-    enableCBLASMLP = (getenv("ENABLE_CBLAS_MLP") ? atoi(getenv("ENABLE_CBLAS_MLP")) : 0);
-}
-
 namespace xft {
 
 void invokeMLPLLaMA(DataType dt, int numTokens, int hiddenSize, int intermediateSize, void *output, int outputStride,

src/layers/mlp_llama.h

Lines changed: 19 additions & 92 deletions
@@ -23,9 +23,6 @@
 #include "singleton.h"
 #include "timeline.h"
 
-extern bool enableCATMLP;
-extern bool enableCBLASMLP;
-void setMLPOPTConfig();
 // C++ implementation for the python code in modeling_llama.py:
 // residual = hidden_states
 // hidden_states = self.post_attention_layernorm(hidden_states)
@@ -65,8 +62,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         ctx->mmHelper->convertWeight(ctx, trans, hiddenSize, imSize, upW, upS, upZ, true, quantizedUpWeight,
                 upWeightScale, upWeightZero, upWeightSum);
 
-        setMLPOPTConfig();
-        if (!enableCATMLP) {
+        if (!enableCATMLP()) {
             gateWeight.Resize(hiddenSize, it.second - it.first);
             upWeight.Resize(hiddenSize, it.second - it.first);
             ctx->mmHelper->packWeight(trans, quantizedGateWeight, gateWeight);
@@ -82,14 +78,9 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
             ctx->mmHelper->packWeight(trans, quantizedCatWeights, catWeights);
         }
         // Horizontally split the down weight
-        if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false, downWeight,
-                    downWeightScale, downWeightZero, downWeightSum);
-        } else {
-            ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false,
-                    quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum);
-            ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight);
-        }
+        ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false,
+                quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum);
+        ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight);
 
 #ifdef DEBUG
         dbg.debugPrint("quantizedGateWeight:\n");
@@ -137,7 +128,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         dbg.dumpMatrix(normBuffer);
 #endif
 
-        if (!enableCATMLP) {
+        if (!enableCATMLP()) {
             hpj::Matrix<ImT> imBuffer(
                     (ImT *)ctx->imOut.Data(), ctx->imOut.Rows(), ctx->imOut.Cols(), ctx->imOut.Stride());
             gateProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer);
@@ -165,31 +156,19 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         hpj::Matrix<ImT> imBuffer((ImT *)ctx->imOut.Data(), M, N, N);
 
         // Need to allocate extra buffer as oneDNN does not support the case of stride > cols
-        if constexpr (std::is_same_v<ImT, bfloat16_t>) {
-            const int cols = N / 2;
-            auto bufSize = M * cols * sizeof(ImT);
-            ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize);
-            hpj::Matrix<ImT> siluBuf(t, M, cols, cols);
-
-            catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, siluBuf);
-#ifdef DEBUG
-            dbg.debugPrint("gateUp output:\n");
-            dbg.dumpMatrix(siluBuf);
-#endif
-            downProj(ctx, siluBuf, outBuffer, inBuffer, ctx->splitIdx == 0);
-        }
+        const int cols = N / 2;
+        auto bufSize = M * cols * sizeof(ImT);
+        ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize);
+        hpj::Matrix<ImT> siluBuf(t, M, cols, cols);
 
-        // Use imBuffer as silu buffer
-        else {
-            catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, imBuffer);
+        catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, siluBuf);
 #ifdef DEBUG
-            dbg.debugPrint("catWeights:\n");
-            dbg.dumpMatrix(catWeights);
-            dbg.debugPrint("gateUp output:\n");
-            dbg.dumpMatrix(imBuffer);
+        dbg.debugPrint("catWeights:\n");
+        dbg.dumpMatrix(catWeights);
+        dbg.debugPrint("gateUp output:\n");
+        dbg.dumpMatrix(siluBuf);
 #endif
-            downProj(ctx, imBuffer, outBuffer, inBuffer, ctx->splitIdx == 0);
-        }
+        downProj(ctx, siluBuf, outBuffer, inBuffer, ctx->splitIdx == 0);
     }
 
 #ifdef DEBUG
@@ -248,7 +227,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         TimeLine t("DownProj");
 
         assert(input.Rows() == output.Rows());
-        if (!enableCATMLP)
+        if (!enableCATMLP())
             assert(input.Cols() == downWeight.Rows());
         else
             assert(input.Cols() == 2 * downWeight.Rows());
@@ -266,62 +245,10 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         const InT *R = residential.Data();
 
         if (isMaster) {
-            // TODO: enable below code (currently disabled as hard to get tmpBuf from pre-alloced memory)
-            // if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            //     computeProjBF16(A, B, C, M, N, K, lda, ldc, ldc, R, ldr, tmpBuf, ldt);
-            // }
-            {
-                ctx->mmHelper->compute_residential(
-                        false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, NULL, R, ldr);
-            }
+            ctx->mmHelper->compute_residential(
+                    false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, NULL, R, ldr);
         } else {
-            // if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            //     computeProjBF16(A, B, C, M, N, K, lda, ldc, ldc, nullptr, 0, tmpBuf, ldt);
-            // }
-            {
-                ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc);
-            }
-        }
-    }
-
-    // C = (R == nullptr ? A * B : A * B + R)
-    // T: temporary buffer if C is not in float
-    void computeProjBF16(const ImT *A, const WeiT *B, OutT *C, int M, int N, int K, int lda, int ldb, int ldc,
-            const InT *R, int ldr, float *T, int ldt) {
-        int alpha = 1.0;
-        int beta = 0.0;
-
-        // MKL needs float as output, use T (temporary buffer) as output if C is not in float
-        float *D = std::is_same_v<OutT, float> ? (float *)C : T;
-        int ldd = std::is_same_v<OutT, float> ? ldc : ldt;
-
-        REQUIRES(D != nullptr, "Incorrect parameter in computeProjBF16.");
-
-        if (R != nullptr) {
-#pragma omp parallel for
-            for (uint64_t i = 0; i < M; ++i) {
-                xft::copy(D + i * ldd, R + i * ldr, N);
-            }
-            beta = 1.0;
-        }
-
-        int ldaH = lda * sizeof(ImT) / sizeof(bfloat16_t); // stride in bf16
-        if constexpr (std::is_same_v<ImT, float>) {
-#pragma omp parallel for
-            for (uint64_t i = 0; i < M; ++i) {
-                bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)A + i * ldaH, K);
-            }
-        }
-
-        cblas_gemm_bf16bf16f32(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, (const MKL_BF16 *)(A), ldaH,
-                (const MKL_BF16 *)(B), ldb, beta, D, ldd);
-
-        // Convert result from float to OutT
-        if constexpr (!std::is_same_v<OutT, float>) {
-#pragma omp parallel for
-            for (uint64_t i = 0; i < M; ++i) {
-                xft::copy(C + i * ldc, D + i * ldd, N);
-            }
+            ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc);
         }
     }
 
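
A note on the comment kept above ("oneDNN does not support the case of stride > cols"): after the concatenated gate/up projection, the SiLU(gate) * up result occupies only N/2 columns of the M x N imBuffer, so a view of it would have stride N but only N/2 columns. The unified path therefore always hands catGateUpProj a dense M x (N/2) "mlp_silu" buffer whose stride equals its column count. The sketch below only illustrates the layout constraint; the View type and the copy loop are hypothetical and are not how the project fills the buffer (catGateUpProj writes into it directly).

```cpp
// Illustrative sketch (hypothetical types, not project code): why a half-width
// strided view of an M x N buffer is replaced by a dense M x N/2 buffer whose
// stride equals its column count before being fed to the next matmul.
#include <cstdio>
#include <vector>

struct View {                // hypothetical row-major matrix view
    float *data;
    int rows, cols, stride;  // stride = elements between consecutive rows
};

int main() {
    const int M = 2, N = 8;
    std::vector<float> gateUpOut(M * N, 1.0f);  // result area of the concatenated gate/up projection

    // Taking only the first N/2 columns gives stride (8) > cols (4),
    // which the oneDNN-backed matmul path does not accept.
    View half{gateUpOut.data(), M, N / 2, N};

    // A dense buffer with stride == cols, analogous to the "mlp_silu" buffer
    // obtained from SimpleMemPool; the copy here is purely for illustration.
    std::vector<float> siluBuf(M * (N / 2));
    for (int i = 0; i < half.rows; ++i)
        for (int j = 0; j < half.cols; ++j)
            siluBuf[i * half.cols + j] = half.data[i * half.stride + j];

    std::printf("packed %zu elements, stride == cols == %d\n", siluBuf.size(), N / 2);
    return 0;
}
```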

src/models/env_config.cpp

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#include <cstdlib>
+#include <iostream>
+#include <stdlib.h>
+
+bool enableCATMLP() {
+    static int catMlp = -1;
+    if (catMlp == -1)
+        catMlp = (getenv("ENABLE_CAT_MLP") ? atoi(getenv("ENABLE_CAT_MLP")) : 1);
+    return catMlp == 1;
+}
+
+int getFlashThresh() {
+    static int envFlashThresh = -1;
+    if (envFlashThresh == -1)
+        envFlashThresh = (getenv("FLASH_ATTN_THRESHOLD") ? atoi(getenv("FLASH_ATTN_THRESHOLD")) : 1024);
+    return envFlashThresh;
+}
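
Both accessors cache the environment variable in a function-local static on the first call, so overrides must be in place before the first attention/MLP forward pass. A minimal usage sketch, assuming these declarations are exposed through a project header and the program links against env_config.cpp:

```cpp
// Minimal sketch: overriding the new knobs before the model runs.
// ENABLE_CAT_MLP (default 1) and FLASH_ATTN_THRESHOLD (default 1024) are read
// once and cached, so they must be set before first use.
#include <cstdio>
#include <stdlib.h>  // setenv (POSIX)

// Declarations as defined in src/models/env_config.cpp (assumed to be exposed
// through a project header).
bool enableCATMLP();
int getFlashThresh();

int main() {
    setenv("FLASH_ATTN_THRESHOLD", "512", /*overwrite=*/1);  // use flash attention for prompts >= 512 tokens
    setenv("ENABLE_CAT_MLP", "0", /*overwrite=*/1);          // fall back to separate gate/up projections

    std::printf("catMLP=%d flashThresh=%d\n", enableCATMLP(), getFlashThresh());
    // ... build the model and run generation; later setenv calls have no effect
    // because the values above are now cached in function-local statics.
    return 0;
}
```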
