
Commit e5ee14b

[Tools] Accelerate model loading.
1 parent fdcbd02 commit e5ee14b

19 files changed, +137 −237 lines changed

include/dtype.h

Lines changed: 1 addition & 0 deletions
@@ -31,5 +31,6 @@ enum class DataType {
     w8a8_int8,
     w8a8_int4,
     w8a8_nf4,
+    unknown,
 };
 } // namespace xft
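
The new unknown enumerator gives callers a sentinel for unrecognized quantization settings. A minimal hypothetical sketch of such a mapping (the helper name and string keys are illustrative, not part of this commit):

#include <string>
#include "dtype.h"

// Hypothetical helper: fall back to DataType::unknown instead of aborting
// or silently picking a default when the string does not match.
inline xft::DataType parseDataTypeDemo(const std::string &s) {
    if (s == "w8a8_int8") return xft::DataType::w8a8_int8;
    if (s == "w8a8_int4") return xft::DataType::w8a8_int4;
    if (s == "w8a8_nf4") return xft::DataType::w8a8_nf4;
    return xft::DataType::unknown;
}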

src/common/float16.h

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,7 @@ inline void float16_t::cvt_float_to_float16(const float *src, float16_t *dst, in
     int remainder = size % kStep;

     // Process blocks of 16 floats at a time
+#pragma omp parallel for
     for (int i = 0; i < blockSize; ++i) {
         // Load the input floats into a AVX512 register
         __m512 input_vector = _mm512_loadu_ps(src + i * kStep);
@@ -156,6 +157,7 @@ inline void float16_t::cvt_float16_to_float(const float16_t *src, float *dst, in
     int blockSize = size / kStep;
     int remainder = size % kStep;

+#pragma omp parallel for
     for (int i = 0; i < blockSize; ++i) {
         __m256i input_vector = _mm256_maskz_loadu_epi16(0xffff, src + i * kStep);
         __m512 output_vector = _mm512_cvtph_ps(input_vector);
@@ -175,6 +177,7 @@ inline void float16_t::float_add_float16(const float *src1, const float16_t *src
     int blockSize = size / kStep;
     int remainder = size % kStep;

+#pragma omp parallel for
     for (int i = 0; i < blockSize; ++i) {
         __m512 vec1 = _mm512_loadu_ps(src1 + i * kStep);
         __m256i _t = _mm256_maskz_loadu_epi16(0xffff, src2 + i * kStep);
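
The speedup comes from splitting the independent 16-float blocks across OpenMP threads; the tail remainder stays serial. A self-contained sketch of the same pattern on raw buffers (the demo function name and the plain uint16_t output are assumptions for illustration; the repo's versions operate on its float16_t type), built with -fopenmp and AVX-512 F/BW/VL enabled:

#include <immintrin.h>
#include <cstdint>

// Convert 'size' floats to IEEE half precision, 16 elements per AVX-512 block.
// Blocks are independent, so the block loop is parallelized with OpenMP.
void cvt_float_to_float16_demo(const float *src, uint16_t *dst, int size) {
    constexpr int kStep = 16;
    const int blockSize = size / kStep;
    const int remainder = size % kStep;

#pragma omp parallel for
    for (int i = 0; i < blockSize; ++i) {
        __m512 v = _mm512_loadu_ps(src + i * kStep);
        __m256i h = _mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst + i * kStep), h);
    }

    // Serial masked tail for the last (size % 16) elements.
    if (remainder > 0) {
        const __mmask16 mask = (1u << remainder) - 1;
        __m512 v = _mm512_maskz_loadu_ps(mask, src + blockSize * kStep);
        __m256i h = _mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        _mm256_mask_storeu_epi16(dst + blockSize * kStep, mask, h);
    }
}

Each iteration touches a disjoint 16-element slice of src and dst, so no synchronization is needed beyond the implicit barrier at the end of the parallel loop.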

src/layers/dist_linear.h

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ class DistLinear {
     //      |                                         |
     //      |                                         | splitSize(N)
     //      |_________________________________________|
-    void setWeight(const float *w, const float *b) {
+    void setWeight(const float *w, const float *b = nullptr) {
         this->splitSize = outputSize / splits;
         this->splitOffset = this->splitSize * splitIdx;

@@ -111,5 +111,5 @@ class DistLinear {
     hpj::Vector<float> scaleWeight; // if weight is int8
     hpj::Vector<float> zeroWeight; // if weight is int8
     hpj::Vector<float> sumWeight; // if weight is int8
-    float *bias;
+    float *bias = nullptr;
 };
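
With the defaulted argument and the nullptr member initializer, a bias-free layer no longer has to pass an explicit null. Hypothetical call site, assuming a constructed DistLinear named lmHead and weight/bias pointers w and bias:

lmHead.setWeight(w);        // bias defaults to nullptr
lmHead.setWeight(w, bias);  // unchanged when a bias buffer exists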

src/layers/layer_norm.cpp

Lines changed: 17 additions & 8 deletions
@@ -17,35 +17,44 @@
 #include <cstdlib>
 #include <cstring>

-#include "layernorm_kernels.h"
 #include "layer_norm.h"
+#include "layernorm_kernels.h"
 #include "timeline.h"

 namespace xft {

 // Layer normalization: only support the norm along last dimension
 LayerNorm::LayerNorm() {
-    weights = nullptr;
+    gamma = nullptr;
+    beta = nullptr;
     normSize = 0;
 }

 LayerNorm::~LayerNorm() {
-    if (weights) { free(weights); }
+    if (gamma) { free(gamma); }
+    if (beta) { free(beta); }
 }

 void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) {
     this->normSize = cols;
-    this->weights = (float *)aligned_alloc(64, 2 * cols * sizeof(float));
-    memcpy(weights, gamma, cols * sizeof(float));
-    memcpy(weights + cols, beta, cols * sizeof(float));
+    this->gamma = (float *)aligned_alloc(64, cols * sizeof(float));
+    this->beta = (float *)aligned_alloc(64, cols * sizeof(float));
+    memcpy(this->gamma, gamma, cols * sizeof(float));
+    memcpy(this->beta, beta, cols * sizeof(float));
+}
+
+void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaPath, int cols) {
+    this->normSize = cols;
+    loadWeight(gammaPath, this->gamma, cols);
+    if (betaPath != "") loadWeight(betaPath, this->beta, cols);
 }

 // input and output are in shape of (rows, normSize)
 // TODO: column-wise parallel
 void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("LayerNorm.forward");
-    const float *pgamma = weights;
-    const float *pbeta = weights + normSize;
+    const float *pgamma = gamma;
+    const float *pbeta = beta;
     invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride);
 }
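
The new overload relies on loadWeight (declared in weight_util.h) to both allocate the destination and read the file, which is why the aligned_alloc/memcpy pair disappears from the call site. A rough, hypothetical stand-in for that assumed behavior — allocate if the pointer is still null, then read cols raw floats; the real helper may also convert dtypes and validate sizes:

#include <cstdio>
#include <cstdlib>
#include <string>

// Hypothetical sketch only; not the actual weight_util.h implementation.
inline bool loadWeightDemo(const std::string &path, float *&dst, int count) {
    if (dst == nullptr) {
        // Round the byte count up to a multiple of 64 for aligned_alloc.
        size_t bytes = (count * sizeof(float) + 63) / 64 * 64;
        dst = static_cast<float *>(aligned_alloc(64, bytes));
    }
    FILE *fp = std::fopen(path.c_str(), "rb");
    if (fp == nullptr) return false;
    size_t n = std::fread(dst, sizeof(float), count, fp);
    std::fclose(fp);
    return n == static_cast<size_t>(count);
}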

src/layers/layer_norm.h

Lines changed: 6 additions & 2 deletions
@@ -14,6 +14,9 @@
 // ============================================================================
 #pragma once

+#include <string>
+#include "weight_util.h"
+
 namespace xft {

 // Layer normalization: only support the norm along last dimension
@@ -23,6 +26,7 @@ class LayerNorm {
     ~LayerNorm();

     void setWeight(const float *gamma, const float *beta, int cols);
+    void setWeight(const std::string &gammaPath, const std::string &betaPath, int cols);

     // input and output are in shape of (rows, normSize)
     // TODO: column-wise parallel
@@ -31,8 +35,8 @@ class LayerNorm {
 private:
     int normSize;

-    // the weights contains gamma and beta concated together
-    float *weights;
+    float *gamma = nullptr;
+    float *beta = nullptr;
 };

 } // namespace xft
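
Callers can now hand LayerNorm either in-memory buffers or the weight files themselves. A hedged usage sketch mirroring the model-code changes later in this commit (modelPath, hiddenSize, and the buffer names are illustrative):

xft::LayerNorm finalLN;

// Existing overload: gamma/beta already in memory.
finalLN.setWeight(gammaBuf, betaBuf, hiddenSize);

// New overload: load directly from disk; pass "" when there is no bias file.
finalLN.setWeight(modelPath + "/model.final_layernorm.weight.bin",
        modelPath + "/model.final_layernorm.bias.bin", hiddenSize);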

src/layers/rms_norm.cpp

Lines changed: 4 additions & 0 deletions
@@ -38,6 +38,10 @@ void RmsNorm::setWeight(const float *w, const float *, int cols) {
     memcpy(weight, w, cols * sizeof(float));
 }

+void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int cols) {
+    loadWeight(modelPath, weight, cols);
+}
+
 // input and output are in shape of (rows, normSize)
 void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");

src/layers/rms_norm.h

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,7 @@
 #pragma once

 #include "bfloat16.h"
+#include "weight_util.h"

 namespace xft {

@@ -25,6 +26,7 @@ class RmsNorm {
     ~RmsNorm();

     void setWeight(const float *w, const float *, int cols);
+    void setWeight(const std::string &modelPath, const std::string &, int cols);

     // Input and output are in shape of (rows, normSize)
     void forward(const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6);
@@ -41,7 +43,7 @@ class RmsNorm {
     int normSize;

     // the scale weight
-    float *weight;
+    float *weight = nullptr;
 };

 } // namespace xft
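
RmsNorm gets the matching overload; the unnamed second string parameter keeps the signature parallel to LayerNorm but is ignored, since RMSNorm carries no bias. A hypothetical call (the file name and hiddenSize are illustrative):

xft::RmsNorm norm;
norm.setWeight(modelPath + "/model.final_layernorm.weight.bin", "", hiddenSize);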

src/layers/token_embedding.h

Lines changed: 5 additions & 1 deletion
@@ -38,6 +38,10 @@ class TokenEmbedding {
         }
     }

+    void setWeights(const std::string &weightPath) {
+        loadWeight(weightPath, embTable, vocabSize * hiddenSize);
+    }
+
     // tokenIds ia a 2-dimension array with batchSize rows, and seqLen cols
     template <typename OutT>
     void forward(int *tokenIds, OutT *output, int batchSize, int seqLen) {
@@ -57,5 +61,5 @@ class TokenEmbedding {
     int vocabSize;
     int hiddenSize;

-    T *embTable;
+    T *embTable = nullptr;
 };
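
With the new setWeights overload the table is read straight into embTable, so call sites no longer need the temporary float buffer removed in the model diffs below. Usage sketch, assuming an already-constructed embedding object (the pointer name and path come from the Baichuan/ChatGLM call sites in this commit):

// Before: malloc a float buffer, loadWeight into it, setWeights(buffer), free.
// Now: one call; loadWeight fills embTable directly.
embedding->setWeights(modelPath + "/model.wte.bin");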

src/models/baichuan.cpp

Lines changed: 2 additions & 19 deletions
@@ -38,29 +38,12 @@ Baichuan<WeiT>::~Baichuan() {

 template <typename WeiT>
 void Baichuan<WeiT>::setEmbeddingWeights(const std::string &modelPath) {
-    int vocabSize = embedding->getVocabSize();
-    int hiddenSize = embedding->getHiddenSize();
-
-    float *tokenEmb = (float *)malloc(vocabSize * hiddenSize * sizeof(float));
-
-    loadWeight(modelPath + "/model.wte.bin", tokenEmb, vocabSize * hiddenSize, this->getDataType());
-
-    embedding->setWeights(tokenEmb);
-
-    free(tokenEmb);
+    embedding->setWeights(modelPath + "/model.wte.bin");
 }

 template <typename WeiT>
 void Baichuan<WeiT>::setFinalLnWeight(const std::string &modelPath) {
-    int hiddenSize = embedding->getHiddenSize();
-
-    float *gamma = (float *)malloc(hiddenSize * sizeof(float));
-
-    loadWeight(modelPath + "/model.final_layernorm.weight.bin", gamma, hiddenSize, this->getDataType());
-
-    finalLN.setWeight(gamma, nullptr, hiddenSize);
-
-    free(gamma);
+    finalLN.setWeight(modelPath + "/model.final_layernorm.weight.bin", "", embedding->getHiddenSize());
 }

 // Prepare attention_mask which is like:

src/models/chatglm.cpp

Lines changed: 3 additions & 22 deletions
@@ -50,32 +50,13 @@ ChatGLM<WeiT>::~ChatGLM() {

 template <typename WeiT>
 void ChatGLM<WeiT>::setEmbeddingWeights(const std::string &modelPath) {
-    int vocabSize = embedding->getVocabSize();
-    int hiddenSize = embedding->getHiddenSize();
-
-    float *tokenEmb = (float *)malloc(vocabSize * hiddenSize * sizeof(float));
-
-    loadWeight(modelPath + "/model.wte.bin", tokenEmb, vocabSize * hiddenSize, this->getDataType());
-
-    embedding->setWeights(tokenEmb);
-
-    free(tokenEmb);
+    embedding->setWeights(modelPath + "/model.wte.bin");
 }

 template <typename WeiT>
 void ChatGLM<WeiT>::setFinalLnWeight(const std::string &modelPath) {
-    int hiddenSize = embedding->getHiddenSize();
-
-    float *gamma = (float *)malloc(hiddenSize * sizeof(float));
-    float *beta = (float *)malloc(hiddenSize * sizeof(float));
-
-    loadWeight(modelPath + "/model.final_layernorm.weight.bin", gamma, hiddenSize, this->getDataType());
-    loadWeight(modelPath + "/model.final_layernorm.bias.bin", beta, hiddenSize, this->getDataType());
-
-    finalLN.setWeight(gamma, beta, hiddenSize);
-
-    free(gamma);
-    free(beta);
+    finalLN.setWeight(modelPath + "/model.final_layernorm.weight.bin", modelPath + "/model.final_layernorm.bias.bin",
+            embedding->getHiddenSize());
 }

 // Prepare attention_mask
