Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions src/layers/rotary_embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

#include "compile_util.h"

static int max_seq_len_cached = -1;
static int inv_freq_size = -1;
static float *inv_freq;
static float *emb_cos = nullptr;
static float *emb_sin = nullptr;

Expand All @@ -29,38 +27,35 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position
if (!initialized) {
initialized = true;

max_seq_len_cached = max_position_embeddings;
inv_freq_size = (dim + 1) / 2;
inv_freq = (float *)malloc(inv_freq_size * sizeof(float));
float *inv_freq = (float *)malloc(inv_freq_size * sizeof(float));
for (size_t i = 0; i < inv_freq_size; i++) {
inv_freq[i] = 1.0 / pow(base, float(i * 2) / dim);
}

llamaCalEmb();
llamaCalEmb(inv_freq, max_position_embeddings);
} else if (dim != inv_freq_size * 2) {
printf("Incorrect dim=%d, inv_freq_size=%d\n", dim, inv_freq_size);
exit(-1);
}
};

void LlamaRotaryEmbedding::llamaCalEmb() {
emb_cos = (float *)aligned_alloc(64, max_seq_len_cached * (inv_freq_size * 2) * sizeof(float));
emb_sin = (float *)aligned_alloc(64, max_seq_len_cached * (inv_freq_size * 2) * sizeof(float));
void LlamaRotaryEmbedding::llamaCalEmb(const float *inv_freq, const int max_position_embeddings) {
emb_cos = (float *)aligned_alloc(64, max_position_embeddings * inv_freq_size * sizeof(float));
emb_sin = (float *)aligned_alloc(64, max_position_embeddings * inv_freq_size * sizeof(float));

#pragma omp parallel for
for (size_t i = 0; i < max_seq_len_cached; i++) {
float *pcos = emb_cos + i * inv_freq_size * 2;
float *psin = emb_sin + i * inv_freq_size * 2;
for (size_t i = 0; i < max_position_embeddings; i++) {
float *pcos = emb_cos + i * inv_freq_size;
float *psin = emb_sin + i * inv_freq_size;

for (size_t j = 0; j < inv_freq_size; j++) {
float tmp = i * inv_freq[j];
float cos_tmp = std::cos(tmp);
float sin_tmp = std::sin(tmp);

pcos[j] = cos_tmp;
pcos[j + inv_freq_size] = cos_tmp;
psin[j] = sin_tmp;
psin[j + inv_freq_size] = sin_tmp;
}
}
}
Expand All @@ -84,14 +79,31 @@ void LlamaRotaryEmbedding::llamaCalEmb() {
// position_ids: an array of size seq_len
// query and key are matrices laid out as shown below:
//
// |<------------------------------ head_size * head_num --------------------------------->|
// |_head_size|___________________________________________________________________________________
// | | | | | | | | | ^
// | | | | | | | | | |
// | | | | | | | | | bs*seq_len
// | | | | | | | | | |
// | | | | | | | | | |
// |__________|__________|__________|__________|__________|__________|__________|__________|____v__
// |<------------------------------ head_num * head_size --------------------------------->|
// |_head_size|_____________________________________________________________________________ _ _ _ _
// | | | | | | | | | ^
// | | | | | | | | | |
// | | | | | | | | | bs*seq_len
// | | | | | | | | | |
// | | | | | | | | | |
// |__________|__________|__________|__________|__________|__________|__________|__________|_ _ _v_
//
// inv_freq:
// _____
// |_____| 1
// head_size/2
//
// emb_cos: emb_sin:
// _____ _____
// | | | |
// | | | |
// | | | |
// | | | | max_position_embeddings
// | | | |
// | | | |
// |_____| |_____|
// head_size/2 head_size/2

void LlamaRotaryEmbedding::forward(
float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) {
int dim = inv_freq_size * 2;
Expand All @@ -115,22 +127,22 @@ void LlamaRotaryEmbedding::forward(
for (int bs = 0; bs < batchSize; ++bs) {
for (int seq = 0; seq < seqLen; ++seq) {
int pos = positionIds[seq];
float *pcos = emb_cos + pos * dim;
float *psin = emb_sin + pos * dim;
float *pcos = emb_cos + pos * half;
float *psin = emb_sin + pos * half;

float *q = query + bs * seqLen * qStride + seq * qStride + head * dim;
float *k = key + bs * seqLen * kStride + seq * kStride + head * dim;
#pragma omp simd
for (int i = 0; i < half; ++i) {
if (head < qHeads) {
auto q1 = q[i];
q[i] = q[i] * pcos[i] - q[i + half] * psin[i];
q[i + half] = q[i + half] * pcos[i + half] + q1 * psin[i + half];
q[i] = q1 * pcos[i] - q[i + half] * psin[i];
q[i + half] = q[i + half] * pcos[i] + q1 * psin[i];
}
if (head < kHeads) {
auto k1 = k[i];
k[i] = k[i] * pcos[i] - k[i + half] * psin[i];
k[i + half] = k[i + half] * pcos[i + half] + k1 * psin[i + half];
k[i] = k1 * pcos[i] - k[i + half] * psin[i];
k[i + half] = k[i + half] * pcos[i] + k1 * psin[i];
}
}
}
Expand All @@ -155,8 +167,8 @@ void LlamaRotaryEmbedding::forward(
for (int bs = 0; bs < batchSize; ++bs) {
for (int seq = 0; seq < seqLen; ++seq) {
int pos = positionIds[seq];
float *pcos = emb_cos + pos * dim;
float *psin = emb_sin + pos * dim;
float *pcos = emb_cos + pos * half;
float *psin = emb_sin + pos * half;

bfloat16_t *q = query + bs * seqLen * qStride + seq * qStride + head * dim;
bfloat16_t *k = key + bs * seqLen * kStride + seq * kStride + head * dim;
Expand All @@ -167,9 +179,7 @@ void LlamaRotaryEmbedding::forward(
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

__m512 pCosVec = _mm512_maskz_loadu_ps(mask, &pcos[i]);
__m512 pCosHalfVec = _mm512_maskz_loadu_ps(mask, &pcos[i + half]);
__m512 pSinVec = _mm512_maskz_loadu_ps(mask, &psin[i]);
__m512 pSinHalfVec = _mm512_maskz_loadu_ps(mask, &psin[i + half]);

// Compute something like:
// q[i] = q[i] * pcos[i] - q[i + half] * psin[i];
Expand All @@ -178,7 +188,7 @@ void LlamaRotaryEmbedding::forward(
__m512 qVec = bfloat16_t::cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, &q[i]));
__m512 qHalfVec = bfloat16_t::cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, &q[i + half]));
__m512 qNew = _mm512_fmsub_ps(qVec, pCosVec, _mm512_mul_ps(qHalfVec, pSinVec));
__m512 qHalfNew = _mm512_fmadd_ps(qHalfVec, pCosHalfVec, _mm512_mul_ps(qVec, pSinHalfVec));
__m512 qHalfNew = _mm512_fmadd_ps(qHalfVec, pCosVec, _mm512_mul_ps(qVec, pSinVec));
_mm256_mask_storeu_epi16(&q[i], mask, bfloat16_t::cvt_fp32_to_bf16(qNew));
_mm256_mask_storeu_epi16(&q[i + half], mask, bfloat16_t::cvt_fp32_to_bf16(qHalfNew));
}
Expand All @@ -187,7 +197,7 @@ void LlamaRotaryEmbedding::forward(
__m512 kVec = bfloat16_t::cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, &k[i]));
__m512 kHalfVec = bfloat16_t::cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, &k[i + half]));
__m512 kNew = _mm512_fmsub_ps(kVec, pCosVec, _mm512_mul_ps(kHalfVec, pSinVec));
__m512 kHalfNew = _mm512_fmadd_ps(kHalfVec, pCosHalfVec, _mm512_mul_ps(kVec, pSinHalfVec));
__m512 kHalfNew = _mm512_fmadd_ps(kHalfVec, pCosVec, _mm512_mul_ps(kVec, pSinVec));
_mm256_mask_storeu_epi16(&k[i], mask, bfloat16_t::cvt_fp32_to_bf16(kNew));
_mm256_mask_storeu_epi16(&k[i + half], mask, bfloat16_t::cvt_fp32_to_bf16(kHalfNew));
}
Expand Down
2 changes: 1 addition & 1 deletion src/layers/rotary_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LlamaRotaryEmbedding {
bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds);

private:
void llamaCalEmb();
void llamaCalEmb(const float *inv_freq, const int max_position_embeddings);

private:
static bool initialized;
Expand Down