Skip to content

[Kernel] Add GPU kernels and enable LLaMA model. #372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
aff6786
[Kernel] Add GPU kernels.
changqi1 May 7, 2024
0b1be0e
format code
changqi1 May 7, 2024
3ef6143
fix running issue.
changqi1 May 7, 2024
619e788
Add RmsNorm kernel.
changqi1 May 15, 2024
39ec0b4
Fix some issues.
changqi1 May 27, 2024
2ef7f7a
Optimize alloc
changqi1 May 27, 2024
29f01c1
Use unified onednn engine
changqi1 May 27, 2024
532445b
merge from main
changqi1 May 27, 2024
17f224e
Fix compile
changqi1 May 27, 2024
b61fe52
Fix onednn gemm issue
changqi1 May 27, 2024
c5b7ac7
Fix build
changqi1 Jun 3, 2024
0faa9de
Add fp16 rope kernels
changqi1 Jun 4, 2024
3cfa650
Fix attention UT issue.
changqi1 Jun 4, 2024
881bc78
Fix ICX build issue.
changqi1 Jun 4, 2024
c958a1d
Merge branch 'main' into changqing/feature/gpu_rope
changqi1 Jun 4, 2024
9b2af7e
Fix build.
changqi1 Jun 4, 2024
c034849
Add rmsNorm impl and XFT_DEBUG
changqi1 Jun 7, 2024
ec463a3
Update.
changqi1 Jun 7, 2024
69dd33a
update.
changqi1 Jun 7, 2024
5f43897
Add GPU memory to run kernels.
changqi1 Jun 12, 2024
23d2053
Add gpu matmul kernels
changqi1 Jun 12, 2024
b7dc9eb
Fix CPU build issue.
changqi1 Jun 13, 2024
daec9dd
fix
changqi1 Jun 13, 2024
c3e83f2
fix
changqi1 Jun 13, 2024
277de9b
fix
changqi1 Jun 13, 2024
dd1d3fb
fix
changqi1 Jun 13, 2024
15fc202
Fix build issue.
changqi1 Jun 13, 2024
726d356
Fix build issue.
changqi1 Jun 13, 2024
003c46b
Fix LN bug
changqi1 Jun 13, 2024
4cb98cf
Fix final LN
changqi1 Jun 13, 2024
6a85769
Fix 2
changqi1 Jun 13, 2024
f6e6e64
Fix 3
changqi1 Jun 13, 2024
5f93020
Done
changqi1 Jun 13, 2024
175c4dc
Finish
changqi1 Jun 14, 2024
8d35cfc
change macro GPU to XFT_GPU
changqi1 Jun 14, 2024
ea1679d
Add requirements-gpu.txt
changqi1 Jun 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/common/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
#pragma once
#include <cstdio>
#include <cstdlib>
#include <sys/mman.h>
#include <cstring>
#include "environment.h"
#include <sys/mman.h>

#ifdef GPU
#include <CL/sycl.hpp>
#endif

namespace xft {

Expand All @@ -26,10 +31,22 @@ static inline bool is_thp_alloc(size_t nbytes) {
return (Env::getInstance().getTHPEnabled() && (nbytes >= g_thp_threshold));
}

static inline void *alloc(size_t nbytes, size_t alignment = 64) {
static inline void *alloc(size_t nbytes, void *device = nullptr, size_t alignment = 64) {
if (nbytes == 0) { return nullptr; }

void *data;
void *data = nullptr;

#ifdef GPU
if (device != nullptr) {
sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
data = sycl::malloc_device<char>(nbytes, *gpu_queue);
if (data == nullptr) {
printf("Unable to allocate buffer with size of %zu in GPU.\n", nbytes);
exit(-1);
}
return data;
}
#endif

int err = posix_memalign(&data, alignment, nbytes);
if (err != 0) {
Expand All @@ -47,4 +64,28 @@ static inline void *alloc(size_t nbytes, size_t alignment = 64) {

return data;
}

// Releases a buffer previously obtained from xft::alloc().
// `device`, when non-null, must point at the sycl::queue the buffer was
// allocated on; a null `device` means the buffer lives in host memory.
static inline void dealloc(void *data, void *device = nullptr) {
#ifdef GPU
    if (device) {
        auto *queue = static_cast<sycl::queue *>(device);
        sycl::free(data, *queue);
        return;
    }
#endif
    free(data);
}

// Copies `size` bytes from `src` to `dst`.
// With a non-null `device` (a sycl::queue pointer) the copy runs on that
// queue and blocks until completion, so either side may be device memory;
// otherwise a plain host memcpy is performed.
static inline void memcopy(void *dst, const void *src, size_t size, void *device = nullptr) {
#ifdef GPU
    if (device) {
        auto *queue = static_cast<sycl::queue *>(device);
        queue->memcpy(dst, src, size).wait();
        return;
    }
#endif
    memcpy(dst, src, size);
}

} // namespace xft
71 changes: 71 additions & 0 deletions src/common/sequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <queue>
#include <unordered_map>

#include "allocator.h"
#include "environment.h"
#include "sampling_params.h"

Expand Down Expand Up @@ -81,6 +82,20 @@ class SequenceMeta {
, promptTokens(_inputSeqLen, 0)
, step(0) {}

// Builds a sequence whose prompt is the given token list; inputSeqLen is
// derived from the vector's size, no tokens have been processed yet
// (pastSeqLen = 0), and generation starts at step 0.
SequenceMeta(int32_t _sequenceID, std::vector<int32_t> &_promptTokens)
: sequenceID(_sequenceID)
, inputSeqLen(_promptTokens.size())
, pastSeqLen(0)
, promptTokens(_promptTokens)
, step(0) {}

// Builds a sequence of the given input length with zero-filled placeholder
// prompt tokens (to be filled in later), starting at step 0.
// NOTE(review): the constructor ending just above this one appears to use the
// same member-init pattern — confirm this overload is not a duplicate of an
// existing (int32_t, int32_t) constructor.
SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen)
: sequenceID(_sequenceID)
, inputSeqLen(_inputSeqLen)
, pastSeqLen(0)
, promptTokens(_inputSeqLen, 0)
, step(0) {}

~SequenceMeta() {}

int32_t getSequenceID() const { return sequenceID; }
Expand Down Expand Up @@ -207,6 +222,38 @@ class SequenceGroupMeta {
groupID = sequences[0].getSequenceID();
}

// Creates a group of numBeams sequences that all share the same prompt token
// list and the supplied sampling configuration. The group's ID is the first
// sequence's ID (all beams are created with the same _sequenceID).
SequenceGroupMeta(int32_t _sequenceID, std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) {
sequences.reserve(samplingMeta.config.numBeams);
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
}
groupID = sequences[0].getSequenceID();
}

// Creates a group of numBeams sequences of the given input length (prompt
// tokens zero-filled placeholders) with the supplied sampling configuration.
// The group's ID is the first sequence's ID.
SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) {
sequences.reserve(samplingMeta.config.numBeams);
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
}
groupID = sequences[0].getSequenceID();
}

// Creates a group using the default-constructed samplingMeta member.
// NOTE(review): this reads samplingMeta.config.numBeams before any explicit
// initialization — assumes SamplingMeta's default numBeams is >= 1, otherwise
// sequences is empty and sequences[0] below is undefined behavior; confirm.
SequenceGroupMeta(int32_t _sequenceID, std::vector<int32_t> &_inputTokens) {
sequences.reserve(samplingMeta.config.numBeams);
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
}
groupID = sequences[0].getSequenceID();
}

// Creates a group of given input length using the default-constructed
// samplingMeta member (prompt tokens are zero-filled placeholders).
// NOTE(review): same assumption as the overload above — default
// samplingMeta.config.numBeams must be >= 1 for sequences[0] to be valid.
SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen) {
sequences.reserve(samplingMeta.config.numBeams);
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
}
groupID = sequences[0].getSequenceID();
}

int32_t getGroupID() { return groupID; }

int32_t getGroupSize() { return samplingMeta.config.numBeams; }
Expand Down Expand Up @@ -272,6 +319,30 @@ class SequencePool {
return group;
}

// Allocates a sequence group for the given prompt tokens and sampling
// config, registers it in this pool, and hands the pointer back.
// NOTE(review): add()'s boolean result is ignored — presumably the ID is
// always fresh here; confirm against callers.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector<int32_t> &inputTokens, SamplingMeta &samplingMeta_) {
    SequenceGroupMeta *group = new SequenceGroupMeta(sequenceID, inputTokens, samplingMeta_);
    add(group);
    return group;
}

// Allocates a sequence group of the given input length and sampling config,
// registers it in this pool, and hands the pointer back.
// NOTE(review): add()'s boolean result is ignored — confirm the ID cannot
// already be present.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen, SamplingMeta &samplingMeta_) {
    SequenceGroupMeta *group = new SequenceGroupMeta(sequenceID, inputSeqLen, samplingMeta_);
    add(group);
    return group;
}

// Allocates a sequence group for the given prompt tokens using the group's
// default sampling configuration, registers it, and returns the pointer.
// NOTE(review): add()'s boolean result is ignored — confirm the ID is fresh.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector<int32_t> &inputTokens) {
    SequenceGroupMeta *group = new SequenceGroupMeta(sequenceID, inputTokens);
    add(group);
    return group;
}

// Allocates a sequence group of the given input length using the group's
// default sampling configuration, registers it, and returns the pointer.
// NOTE(review): add()'s boolean result is ignored — confirm the ID is fresh.
SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen) {
    SequenceGroupMeta *group = new SequenceGroupMeta(sequenceID, inputSeqLen);
    add(group);
    return group;
}

bool add(SequenceGroupMeta *sequenceGroup, bool force = false) {
int32_t groupID = sequenceGroup->getGroupID();
bool isSuccess = false;
Expand Down
11 changes: 8 additions & 3 deletions src/common/transformer_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ struct DecoderContext {
xft::Matrix<float> qkvMatMul; // query, key, value
xft::Matrix<float> imOut; // intermediate output

MMHelper *mmHelper;
MMHelper *mmHelper = nullptr;
void *device = nullptr;

std::string configPath;
INIReader configReader;
Expand Down Expand Up @@ -240,8 +241,12 @@ struct DecoderContext {
bool cached(const std::string &name) { return SimpleMemPool::instance().cached(name); }

template <typename T>
T *getBuffer(const std::string &name, size_t size, size_t alignment = 64) {
return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, alignment);
T *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) {
return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, device, alignment);
}

// Releases the named buffer from the shared memory pool. Pass the owning
// device (an opaque sycl::queue pointer) for GPU-resident buffers, or
// nullptr (the default) for host buffers.
void freeBuffer(const std::string &name, void *device = nullptr) {
SimpleMemPool::instance().freeBuffer(name, device);
}

void dump() {
Expand Down
79 changes: 79 additions & 0 deletions src/kernels/rotary_embedding_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,4 +386,83 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
maxSupportedSeqLength, qkShape, positionIds);
}

#ifdef GPU
// LLaMA rotary position embedding (RoPE), GPU implementation.
//
// For each token, head, and rotary pair index i in [0, head_size/2) the
// kernel rotates the (i, i + half) element pair by the angle for the
// token's position:
//   q'[i]        = q[i] * cos - q[i + half] * sin
//   q'[i + half] = q[i + half] * cos + q[i] * sin
// and likewise for the key. Query and key are updated in place.
//
// device          : opaque pointer to the sycl::queue to run on
// query/key       : activations; NOTE(review): assumed device-accessible
//                   (USM) memory, as are emb_cos/emb_sin — confirm callers
// qStride/kStride : element stride between consecutive tokens
// emb_cos/emb_sin : cos/sin tables indexed as [position][half_head_dim]
// inv_freq_size   : half the head size; head_size must equal 2 * inv_freq_size
// qkShape         : {batchSize, seqLen, qHeads, headSize, kHeads, ...}
// positionIds     : per-token positions in host memory, length seqLen
//                   (shared across the batch via idx % seqLen)
template <typename T>
static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, int qStride, int kStride, float *emb_cos,
        float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) {
    int dim = inv_freq_size * 2;
    // Fix: error message typo ("dimention" -> "dimension").
    REQUIRES(dim == qkShape[3], "Incorrect shape, this dimension is not the head size.");

    const int batchSize = qkShape[0];
    const int seqLen = qkShape[1];
    const int qHeads = qkShape[2];
    const int kHeads = qkShape[4];
    const int head_num = std::max(qHeads, kHeads);
    const int head_size = qkShape[3];
    const int half_head_size = (head_size + 1) / 2;
    using namespace sycl;

    auto rope_kernel
            = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, const int kHeads,
                      const int seq_size, const int head_size, const int half, T *query, T *key, int qStride,
                      int kStride, const sycl::accessor<int, 1, sycl::access::mode::read> &positionIds) {
                  // One work-item handles one (token, head, rotary-pair) triple.
                  size_t idx_bs_seq = item.get_global_id(0);
                  size_t idx_head_num = item.get_global_id(1);
                  size_t idx_half_head_dim = item.get_global_id(2);

                  size_t pos = positionIds[idx_bs_seq % seq_size];
                  float cos = embCos[pos * half + idx_half_head_dim];
                  float sin = embSin[pos * half + idx_half_head_dim];

                  T *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim;
                  T *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim;

                  // The grid spans max(qHeads, kHeads); guard each side so
                  // the smaller head count is not written out of bounds.
                  if (idx_head_num < qHeads) {
                      auto q1 = q[0];
                      q[0] = q1 * cos - q[half] * sin;
                      q[half] = q[half] * cos + q1 * sin;
                  }
                  if (idx_head_num < kHeads) {
                      auto k1 = k[0];
                      k[0] = k1 * cos - k[half] * sin;
                      k[half] = k[half] * cos + k1 * sin;
                  }
              };

    // Stage positionIds (host memory) through a SYCL buffer/accessor, then
    // launch one work-item per (batch*seq, head, half-dim) element.
    sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
    sycl::buffer<int, 1> positionIdsBuf(positionIds, sycl::range<1>(seqLen));
    gpu_queue->submit([&](sycl::handler &cgh) {
        sycl::accessor position(positionIdsBuf, cgh, sycl::read_only);
        sycl::range<3> globalSize(batchSize * seqLen, head_num, half_head_size);
        sycl::range<3> workGroupSize(1, 1, 1);

        cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=](sycl::nd_item<3> item) {
            rope_kernel(item, emb_cos, emb_sin, qHeads, kHeads, seqLen, head_size, half_head_size, query, key, qStride,
                    kStride, position);
        });
    });
    // Block until the kernel finishes so callers may immediately reuse or
    // free query/key/positionIds.
    gpu_queue->wait();
}

// GPU RoPE entry point for fp32 query/key; forwards to the templated
// implementation above.
void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) {
llamaApplyRotaryPosEmbeding<float>(
device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

// GPU RoPE entry point for bf16 query/key; forwards to the templated
// implementation above.
void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *key, int qStride, int kStride,
float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) {
llamaApplyRotaryPosEmbeding<bfloat16_t>(
device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}

// GPU RoPE entry point for fp16 query/key; forwards to the templated
// implementation above.
void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride,
float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) {
llamaApplyRotaryPosEmbeding<float16_t>(
device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds);
}
#endif

} // namespace xft
14 changes: 13 additions & 1 deletion src/kernels/rotary_embedding_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void llamaApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStride
void llamaApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

// For continous batching
// For LLaMA continuous batching
void llamaApplyRotaryPosEmbed(float *query, float *key, float *embCos, float *embSin, int qStride, int kStride, int dim,
int totSeqLen, int qHeads, int kHeads, const int *positionIds);

Expand Down Expand Up @@ -65,4 +65,16 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
const int *positionIds);

#ifdef GPU
// For LLaMA
// GPU (SYCL) overloads of the LLaMA RoPE kernel. `device` is an opaque
// pointer to the sycl::queue that owns the query/key buffers; the remaining
// parameters mirror the CPU overloads declared above.
void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos,
float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *key, int qStride, int kStride,
float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);

void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride,
float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);
#endif

} // namespace xft
Loading