@@ -571,7 +571,7 @@ class Attention {
         int scoreStride = pastSeqLen > 0 ? (pastSeqLen + ctx->inputSeqLen + 15) / 16 * 16 : ctx->inputSeqLen;
         auto bufSizeRequired = ctx->numThreads * mBlockSize * scoreStride;
         if (bufSizeRequired > ctx->getScoreCapacity()) {
-            scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", bufSizeRequired * sizeof(float));
+            scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * bufSizeRequired);
         }

 #pragma omp parallel for collapse(3)
@@ -680,7 +680,7 @@ class Attention {
         }

         float *shardedOut = (float *)SimpleMemPool::instance().getBuffer(
-                "shardedOutput", totalTasks * ctx->attHeadSize * sizeof(float));
+                "shardedOutput", sizeof(float) * totalTasks * ctx->attHeadSize);

 #pragma omp parallel for collapse(3)
         for (int b = 0; b < batchSize; ++b) {
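Note on the getBuffer() size arguments changed in the hunks above (and again in the threadBuffers/threadPtrBuffers hunk below): the factors are reordered so that sizeof(float), which has type size_t, comes first. This presumably keeps the whole multiplication in size_t; in the old order the int factors are multiplied in (typically 32-bit) int first and can overflow before the result is widened. A minimal standalone sketch of the difference, using hypothetical sizes large enough to overflow a 32-bit int:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Hypothetical values whose product exceeds INT_MAX (illustration only).
        int nth = 64;
        int arrStride = 40 * 1024 * 1024;

        // Old operand order: nth * arrStride is evaluated in int and overflows
        // (signed overflow, undefined behavior) before sizeof(float) widens it.
        size_t oldBytes = nth * arrStride * sizeof(float);

        // New operand order: sizeof(float) is size_t, so every multiplication
        // in the chain is carried out in size_t and the byte count stays correct.
        size_t newBytes = sizeof(float) * nth * arrStride;

        std::printf("old: %zu, new: %zu\n", oldBytes, newBytes);
        return 0;
    }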
@@ -835,6 +835,7 @@ class Attention {
         // TODO: kv dtype conversion for prefixSharing
         AttnT *k, *v;
         int kvStride;
+        // Forcibly convert to AttnT to accelerate the computation
         if constexpr (!std::is_same_v<AttnT, ImT>) {
             // Timer tmc(true, "convert KV matrix into bf16");
             kvStride = kvCols * 2;
@@ -866,28 +867,10 @@ class Attention {
 
         // [batch, src, head, headsize]
         scaledDpAttention<AttnT>(query.Data(), k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads,
-                respKVHeads, headSize, result.Data(), qkvCols, kvStride, result.Stride());
+                respKVHeads, headSize, result.Data(), query.Stride(), kvStride, result.Stride());

-        // For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
-        // When M dimension is split, also multiple tasks per copy, so do the copy separately
-#pragma omp parallel for collapse(3)
-        for (uint64_t b = 0; b < batchSize; ++b) {
-            for (uint64_t i = 0; i < (this->endKVHead - this->startKVHead); ++i) {
-                // Copy current key/value to cached keys/values
-                // Re-layout is needed: (bs, seq=1, hidden_size) -> (seq=1, bs, hidden_size)
-                // Note: for group attention, there are fewer key/values than queries
-                for (uint64_t seq = 0; seq < tgtLen; ++seq) {
-                    auto srcK = key.Data() + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
-                    auto dstK = presentKey.getSequence(pastSeqLen + seq, b, i);
-
-                    auto srcV = value.Data() + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
-                    auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i);
-
-                    xft::copy(dstK, srcK, headSize);
-                    xft::copy(dstV, srcV, headSize);
-                }
-            }
-        }
+        // Copy current key/values to the KV cache
+        copyKVCache(ctx, key, value, presentKey, presentValue, pastSeqLen);
     }

     // scaled dot-product attention: bmm1 + softmax + bmm2
@@ -908,9 +891,9 @@ class Attention {
 
         int numArr = 7;
         int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
-        float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", nth * arrStride * sizeof(float));
+        float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", sizeof(float) * nth * arrStride);
         float **thrPtrBuf
-                = (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", nth * numArr * sizeof(float *));
+                = (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", sizeof(float *) * nth * numArr);

         float **preSum = thrPtrBuf;
         float **sum = thrPtrBuf + nth;
@@ -930,7 +913,7 @@ class Attention {
             qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
         }

-#pragma omp parallel for collapse(3)
+#pragma omp parallel for collapse(3) schedule(dynamic)
         for (uint64_t i = 0; i < batchSize; ++i) {
             for (int j = 0; j < numQHead; ++j) {
                 for (int m = 0; m < srcLen; m += srcBlk) {
@@ -968,6 +951,11 @@ class Attention {
                     for (int b = 0; b < tgtLen; b += tgtBlk) {
                         int kvRealBlk = std::min(tgtBlk, tgtLen - b);
                         // TODO: mask out
+                        if (enableSkipMsk() && DecoderUtil::skipMskAttn(attnMsk + b, qRealBlk, kvRealBlk, tgtLen)) {
+                            // printf("Skip bs %d head %d src %d tgt %d\n", i, j, m, b);
+                            break;
+                        }
+
                         const AttnT *kBlk = k + b * kvStride;
                         const AttnT *vBlk = v + b * kvStride;
 
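The last two hunks work together: the loop over KV blocks now breaks out as soon as DecoderUtil::skipMskAttn() reports that the remaining attention-mask tile is fully masked, and the outer collapsed loop switches to schedule(dynamic), presumably because iterations that hit the early exit finish sooner and static scheduling would leave threads idle. The exact criterion used by skipMskAttn() is not shown in this diff; a minimal sketch of the kind of check such a helper could perform, assuming the mask is an additive bias where masked positions hold a large negative value (hypothetical helper, not the real DecoderUtil API):

    // Returns true if every position in the qRealBlk x kvRealBlk tile of the
    // additive attention mask is masked out (large negative bias), so the
    // corresponding KV block contributes nothing and can be skipped.
    // mskTile points at the first column of the tile; rows are tgtLen apart.
    static bool blockFullyMasked(const float *mskTile, int qRealBlk, int kvRealBlk, int tgtLen) {
        for (int r = 0; r < qRealBlk; ++r) {
            for (int c = 0; c < kvRealBlk; ++c) {
                if (mskTile[r * tgtLen + c] > -10000.0f) return false;
            }
        }
        return true;
    }

Note that the added code uses break rather than continue, which matches a causal mask: once a block is fully masked for the current query rows, all later KV blocks are masked as well.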