Skip to content

Commit 4cfb6d9

Browse files
committed
Add KVCache for long sequence && tuned comm for faster Addreduce
1 parent b29259a commit 4cfb6d9

File tree

9 files changed

+228
-91
lines changed

9 files changed

+228
-91
lines changed

src/common/kvcache_tensor.h

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
#include "allocator.h"
2323

24+
extern bool kvTrans();
25+
2426
/**
2527
* Tensor specially designed for KV Cache
2628
* Naturaly, it could be represented in the shape of [seq_length][batch_size][head_num][head_size]
@@ -92,13 +94,26 @@ class KVCacheTensor {
9294

9395
// Get a vector for a specified sequence
9496
T *getSequence(int seqIdx, int batchIdx, int headIdx) {
95-
return data + (seqIdx * batchSize + batchIdx) * (headNum * headSize) + headIdx * headSize;
97+
if (kvTrans()) {
98+
// [batchSize, headNum, seq, headSize] but also need to modify expand and reorder function
99+
return data + (uint64_t)batchIdx * headNum * maxSeqLen * headSize + (uint64_t)headIdx * maxSeqLen * headSize + (uint64_t)seqIdx * headSize;
100+
} else {
101+
// [seqLen, batchSize, headNum, headSize] but also need to modify expand and reorder function
102+
return data + (uint64_t)seqIdx * batchSize * headNum * headSize + (uint64_t)batchIdx * headNum * headSize + (uint64_t)headIdx * headSize;
103+
}
96104
}
97105

98106
// Get a head matrix, return the start address and the stride
99107
std::pair<T *, int> getHead(int batchIdx, int headIdx) {
100-
T *addr = data + batchIdx * headNum * headSize + headIdx * headSize;
101-
return std::make_pair(addr, batchSize * headNum * headSize);
108+
if (kvTrans()) {
109+
// [batchSize, headNum, seq, headSize] but also need to modify expand and reorder function
110+
T *addr = data + batchIdx * headNum * maxSeqLen * headSize + headIdx * maxSeqLen * headSize;
111+
return std::make_pair(addr, headSize);
112+
} else {
113+
// [seqLen, batchSize, headNum, headSize] but also need to modify expand and reorder function
114+
T *addr = data + (uint64_t)batchIdx * headNum * headSize + (uint64_t)headIdx * headSize;
115+
return std::make_pair(addr, batchSize * headNum * headSize);
116+
}
102117
}
103118

104119
/**
@@ -120,37 +135,34 @@ class KVCacheTensor {
120135
return;
121136
}
122137

138+
if (!kvTrans()) {
123139
#pragma omp parallel for
124-
for (int seq = 0; seq < seqLen; ++seq) {
125-
for (int b = batchSize - 1; b > 0; --b) {
126-
T *dst = getSequence(seq, b, 0);
127-
T *src = getSequence(seq, b / beamSize, 0);
128-
memcpy(dst, src, headNum * headSize * sizeof(T));
140+
for (int seq = 0; seq < seqLen; ++seq) {
141+
for (int b = batchSize - 1; b > 0; --b) {
142+
T *dst = getSequence(seq, b, 0);
143+
T *src = getSequence(seq, b / beamSize, 0);
144+
memcpy(dst, src, sizeof(T) * headNum * headSize);
145+
}
129146
}
147+
} else {
148+
printf("Unsupported kv tensor optimization [ENABLE_KV_TRANS] in beam search for now.\n");
149+
exit(-1);
130150
}
131151
}
132152

133153
void expandOneSequence(int userSideBS, int beamSize, int seq) {
134-
for (int b = batchSize - 1; b > 0; --b) {
135-
T *dst = getSequence(seq, b, 0);
136-
T *src = getSequence(seq, b / beamSize, 0);
137-
memcpy(dst, src, headNum * headSize * sizeof(T));
154+
if (!kvTrans()) {
155+
for (int b = batchSize - 1; b > 0; --b) {
156+
T *dst = getSequence(seq, b, 0);
157+
T *src = getSequence(seq, b / beamSize, 0);
158+
memcpy(dst, src, sizeof(T) * headNum * headSize);
159+
}
160+
} else {
161+
printf("Unsupported kv tensor optimization [ENABLE_KV_TRANS] in beam search for now.\n");
162+
exit(-1);
138163
}
139164
}
140165

141-
// Below implementation could be a little faster (100.6 vs. 100.9), but also need to modify expand and reorder function
142-
143-
// // Get a vector for a specified sequence
144-
// T *getSequence(int seqIdx, int batchIdx, int headIdx) {
145-
// return data + batchIdx * headNum * maxSeqLen * headSize + headIdx * maxSeqLen * headSize + seqIdx * headSize;
146-
// }
147-
148-
// // Get a head matrix, return the start address and the stride
149-
// std::pair<T *, int> getHead(int batchIdx, int headIdx) {
150-
// T *addr = data + batchIdx * headNum * maxSeqLen * headSize + headIdx * maxSeqLen * headSize;
151-
// return std::make_pair(addr, headSize);
152-
// }
153-
154166
private:
155167
int maxSeqLen;
156168
int batchSize;
@@ -159,4 +171,4 @@ class KVCacheTensor {
159171

160172
T *data;
161173
uint64_t allocSize;
162-
};
174+
};

src/layers/attention.h

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ class Attention {
571571
int scoreStride = pastSeqLen > 0 ? (pastSeqLen + ctx->inputSeqLen + 15) / 16 * 16 : ctx->inputSeqLen;
572572
auto bufSizeRequired = ctx->numThreads * mBlockSize * scoreStride;
573573
if (bufSizeRequired > ctx->getScoreCapacity()) {
574-
scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", bufSizeRequired * sizeof(float));
574+
scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * bufSizeRequired);
575575
}
576576

577577
#pragma omp parallel for collapse(3)
@@ -680,7 +680,7 @@ class Attention {
680680
}
681681

682682
float *shardedOut = (float *)SimpleMemPool::instance().getBuffer(
683-
"shardedOutput", totalTasks * ctx->attHeadSize * sizeof(float));
683+
"shardedOutput", sizeof(float) * totalTasks * ctx->attHeadSize);
684684

685685
#pragma omp parallel for collapse(3)
686686
for (int b = 0; b < batchSize; ++b) {
@@ -835,6 +835,7 @@ class Attention {
835835
// TODO: kv dtype conversion for prefixSharing
836836
AttnT *k, *v;
837837
int kvStride;
838+
// forcibly convert to AttnT for acceleration purposes
838839
if constexpr (!std::is_same_v<AttnT, ImT>) {
839840
//Timer tmc(true, "convert KV matrix into bf16");
840841
kvStride = kvCols * 2;
@@ -866,28 +867,10 @@ class Attention {
866867

867868
// [batch, src, head, headsize]
868869
scaledDpAttention<AttnT>(query.Data(), k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads,
869-
respKVHeads, headSize, result.Data(), qkvCols, kvStride, result.Stride());
870+
respKVHeads, headSize, result.Data(), query.Stride(), kvStride, result.Stride());
870871

871-
// For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
872-
// When M dimension is split, also multiple tasks per copy, so do copy separately
873-
#pragma omp parallel for collapse(3)
874-
for (uint64_t b = 0; b < batchSize; ++b) {
875-
for (uint64_t i = 0; i < (this->endKVHead - this->startKVHead); ++i) {
876-
// Copy current key/value to cached keys/values
877-
// Re-layout is needed: (bs, seq=1, hidden_size) -> (seq=1, bs, hidden_size)
878-
// Be noted: for group attention, the key/value is less than query
879-
for (uint64_t seq = 0; seq < tgtLen; ++seq) {
880-
auto srcK = key.Data() + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
881-
auto dstK = presentKey.getSequence(pastSeqLen + seq, b, i);
882-
883-
auto srcV = value.Data() + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
884-
auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i);
885-
886-
xft::copy(dstK, srcK, headSize);
887-
xft::copy(dstV, srcV, headSize);
888-
}
889-
}
890-
}
872+
// copy current key/values to cache
873+
copyKVCache(ctx, key, value, presentKey, presentValue, pastSeqLen);
891874
}
892875

893876
// scaled dot-product attention: bmm1 + softmax + bmm2
@@ -908,9 +891,9 @@ class Attention {
908891

909892
int numArr = 7;
910893
int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
911-
float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", nth * arrStride * sizeof(float));
894+
float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", sizeof(float) * nth * arrStride);
912895
float **thrPtrBuf
913-
= (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", nth * numArr * sizeof(float *));
896+
= (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", sizeof(float *) * nth * numArr);
914897

915898
float **preSum = thrPtrBuf;
916899
float **sum = thrPtrBuf + nth;
@@ -930,7 +913,7 @@ class Attention {
930913
qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
931914
}
932915

933-
#pragma omp parallel for collapse(3)
916+
#pragma omp parallel for collapse(3) schedule(dynamic)
934917
for (uint64_t i = 0; i < batchSize; ++i) {
935918
for (int j = 0; j < numQHead; ++j) {
936919
for (int m = 0; m < srcLen; m += srcBlk) {
@@ -968,6 +951,11 @@ class Attention {
968951
for (int b = 0; b < tgtLen; b += tgtBlk) {
969952
int kvRealBlk = std::min(tgtBlk, tgtLen - b);
970953
// TODO: mask out
954+
if (enableSkipMsk() && DecoderUtil::skipMskAttn(attnMsk + b, qRealBlk, kvRealBlk, tgtLen)) {
955+
// printf("Skip bs %d head %d src %d tgt %d\n", i, j, m, b);
956+
break;
957+
}
958+
971959
const AttnT *kBlk = k + b * kvStride;
972960
const AttnT *vBlk = v + b * kvStride;
973961

src/models/env_config.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,39 @@ bool enableCATMLP() {
2323
return catMlp == 1;
2424
}
2525

26+
bool tunedComm() {
27+
static int tunedComm = -1;
28+
if (tunedComm == -1) {
29+
tunedComm = (getenv("ENABLE_TUNED_COMM") ? atoi(getenv("ENABLE_TUNED_COMM")) : 1);
30+
if (tunedComm == 1)
31+
printf("ENABLE_TUNED_COMM is enabled for faster reduceAdd.\n");
32+
}
33+
return tunedComm == 1;
34+
}
35+
2636
int getFlashThresh() {
2737
static int envFlashThresh = -1;
2838
if (envFlashThresh == -1)
2939
envFlashThresh = (getenv("FLASH_ATTN_THRESHOLD") ? atoi(getenv("FLASH_ATTN_THRESHOLD")) : 1024);
3040
return envFlashThresh;
3141
}
42+
43+
bool enableSkipMsk() {
44+
static int skipMsk = -1;
45+
if (skipMsk == -1) {
46+
skipMsk = (getenv("ENABLE_SKIP_MASK") ? atoi(getenv("ENABLE_SKIP_MASK")) : 0);
47+
if (skipMsk == 1)
48+
printf("ENABLE_SKIP_MASK is enabled for ignoring mask Q*K.\n");
49+
}
50+
return skipMsk == 1;
51+
}
52+
53+
bool kvTrans() {
54+
static int kvTrans = -1;
55+
if (kvTrans == -1) {
56+
kvTrans = (getenv("ENABLE_KV_TRANS") ? atoi(getenv("ENABLE_KV_TRANS")) : 0);
57+
// if (kvTrans == 1)
58+
// printf("ENABLE_KV_TRANS is enabled for kv cache optimization.\n");
59+
}
60+
return kvTrans == 1;
61+
}

src/models/kvcache_manager.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,31 @@ void KVCacheManager<KVCacheT>::expandPrefixCache(int layerId, int userSideBS, in
141141
int headNum = dstTensors[0]->getHeadNum();
142142
int headSize = dstTensors[0]->getHeadSize();
143143

144+
if (!kvTrans()) {
144145
#pragma omp parallel for collapse(2)
145-
for (int i = 0; i < 2; ++i) {
146-
for (int seq = 0; seq < seqLen; ++seq) {
147-
auto *src = srcTensors[i]->getSequence(seq, 0, 0);
148-
for (int b = userSideBS - 1; b >= 0; --b) {
149-
auto *dst = dstTensors[i]->getSequence(seq, b, 0);
150-
memcpy(dst, src, headNum * headSize * sizeof(KVCacheT));
146+
for (int i = 0; i < 2; ++i) {
147+
for (int seq = 0; seq < seqLen; ++seq) {
148+
auto *src = srcTensors[i]->getSequence(seq, 0, 0);
149+
for (int b = userSideBS - 1; b >= 0; --b) {
150+
auto *dst = dstTensors[i]->getSequence(seq, b, 0);
151+
memcpy(dst, src, sizeof(KVCacheT) * headNum * headSize);
152+
}
151153
}
152154
}
155+
} else {
156+
printf("Unsupported kv tensor optimization [ENABLE_KV_TRANS] in Prefix mode for now.\n");
157+
exit(-1);
153158
}
154159
}
155160

156161
// Reorder cached keys and values
157162
// TODO: move to KVCacheTensor is better
158163
template <typename KVCacheT>
159164
void KVCacheManager<KVCacheT>::reorderCache(int *idx, int size, int initSeqLen, int accSeqLen) {
165+
if (kvTrans()) {
166+
printf("Unsupported kv tensor optimization [ENABLE_KV_TRANS] in beam search for now.\n");
167+
exit(-1);
168+
}
160169
// Reorder for all the layers
161170
#pragma omp parallel for
162171
for (int layer = 0; layer < this->layers; ++layer) {
@@ -251,4 +260,4 @@ void KVCacheManager<KVCacheT>::reorderCache(int *idx, int size, int initSeqLen,
251260

252261
template class KVCacheManager<float16_t>;
253262
template class KVCacheManager<bfloat16_t>;
254-
template class KVCacheManager<float>;
263+
template class KVCacheManager<float>;

src/utils/decoder_util.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
#include "transformer_ctx.h"
2929
#include "xdnn.h"
3030

31-
int getFlashThresh();
32-
bool enableCATMLP();
31+
extern int getFlashThresh();
32+
extern bool enableCATMLP();
33+
extern bool enableSkipMsk();
3334

3435
class DecoderUtil {
3536
public:
@@ -580,4 +581,14 @@ class DecoderUtil {
580581
sgemm((T *)AB, C, expABC, m, n, k, k, vStride, n, false, false);
581582
updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride);
582583
}
584+
585+
static bool skipMskAttn(const float *attnMask, int m, int n, int stride) {
586+
float lowest = std::numeric_limits<float>::lowest();
587+
// left bottom is lowest
588+
if (attnMask[(m - 1)* stride] == lowest)
589+
return true;
590+
else
591+
return false;
592+
}
593+
583594
};

src/utils/matmul_helper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,7 @@ class MMHelper {
10671067
if constexpr (std::is_same_v<InT, bfloat16_t>) {
10681068
TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_residential");
10691069
#pragma omp parallel for collapse(2)
1070-
for (int i = 0; i < M; ++i) {
1070+
for (uint64_t i = 0; i < M; ++i) {
10711071
for (int j = 0; j < N; ++j) {
10721072
auto remain = N - j;
10731073
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
@@ -1082,7 +1082,7 @@ class MMHelper {
10821082
if (M > AMXThresholdM) {
10831083
TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_residential");
10841084
#pragma omp parallel for collapse(2)
1085-
for (int i = 0; i < M; ++i) {
1085+
for (uint64_t i = 0; i < M; ++i) {
10861086
for (int j = 0; j < N; ++j) {
10871087
res[i * ldres + j] = res[i * ldres + j] * gamma;
10881088
}
@@ -1624,7 +1624,7 @@ class MMHelper {
16241624
if (C == res) {
16251625
scale_mem = memory(scale_md, *engine);
16261626
#pragma omp parallel for
1627-
for (int i = 0; i < M; ++i) {
1627+
for (uint64_t i = 0; i < M; ++i) {
16281628
memcpy((Tin *)scale_mem.get_data_handle() + i * N, res + i * ldres, N * sizeof(Tin));
16291629
}
16301630
} else {

0 commit comments

Comments
 (0)