
Commit 956fc3c

[Tools] Accelerate model loading.
1 parent 421cd69 commit 956fc3c

22 files changed: 220 additions, 259 deletions

include/dtype.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ enum DataType {
     w8a8_int8,
     w8a8_int4,
     w8a8_nf4,
+    unknown,
 };

 enum DeviceKind {
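
The added `unknown` value gives weight-loading code a sentinel to return when a tensor's data type cannot be determined, instead of aborting. A minimal sketch of such a fallback; the helper name and the string keys are illustrative assumptions, not part of this commit:

#include <string>
#include "dtype.h"

// Hypothetical helper: map a dtype string to DataType, falling back to 'unknown'.
static DataType parseDataType(const std::string &s) {
    if (s == "w8a8_int8") return DataType::w8a8_int8;
    if (s == "w8a8_int4") return DataType::w8a8_int4;
    if (s == "w8a8_nf4") return DataType::w8a8_nf4;
    return DataType::unknown; // caller decides how to handle unrecognized types
}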

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ add_subdirectory(comm_helper)

 add_library(xfastertransformer_static STATIC)

-set(SRC_LIB_LIST "utils" "layers" "kernels" "models" "searchers")
+set(SRC_LIB_LIST "utils" "layers" "kernels" "models" "searchers" "stdc++fs")

 target_link_libraries(xfastertransformer_static
     ${SRC_LIB_LIST}
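
Linking "stdc++fs" pulls in libstdc++'s separate std::filesystem support library, which GCC releases before 9 require whenever <filesystem> is used; the faster loader presumably scans model directories this way. A sketch of the kind of code this enables (the directory path is illustrative):

#include <filesystem>
#include <iostream>

namespace fs = std::filesystem;

int main() {
    // Enumerate converted weight files in a (hypothetical) model directory.
    for (const auto &entry : fs::directory_iterator("/data/llama-2-7b-xft")) {
        if (entry.path().extension() == ".bin")
            std::cout << entry.path().string() << " (" << entry.file_size() << " bytes)\n";
    }
    return 0;
}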

src/common/float16.h

Lines changed: 53 additions & 0 deletions
@@ -42,7 +42,9 @@ class float16_t {
     operator float() const;

     static void cvt_float_to_float16(const float *src, float16_t *dst, int size);
+    static void cvt_float_to_float16_MT(const float *src, float16_t *dst, int size);
     static void cvt_float16_to_float(const float16_t *src, float *dst, int size);
+    static void cvt_float16_to_float_MT(const float16_t *src, float *dst, int size);
     static void float_add_float16(const float *src1, const float16_t *src2, float *dst, int size);

 private:
@@ -150,6 +152,36 @@ inline void float16_t::cvt_float_to_float16(const float *src, float16_t *dst, int size) {
     }
 }

+inline void float16_t::cvt_float_to_float16_MT(const float *src, float16_t *dst, int size) {
+    // Round to nearest even mode
+    constexpr int rounding_mode = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;
+
+    // Process 16 floats (AVX512 is a 512-bit SIMD register)
+    constexpr int kStep = 16;
+    int blockSize = size / kStep;
+    int remainder = size % kStep;
+
+    // Process blocks of 16 floats at a time
+#pragma omp parallel for
+    for (int i = 0; i < blockSize; ++i) {
+        // Load the input floats into a AVX512 register
+        __m512 input_vector = _mm512_loadu_ps(src + i * kStep);
+
+        // Convert the floats to float16_t using AVX512 intrinsics
+        __m256i output_vector = _mm512_cvtps_ph(input_vector, rounding_mode);
+
+        // Store the converted values in the output array
+        _mm256_mask_storeu_epi16(dst + i * kStep, 0xffff, output_vector);
+    }
+
+    if (remainder != 0) {
+        __mmask16 mask = 0xFFFF >> (kStep - remainder);
+        __m512 input_vector = _mm512_maskz_loadu_ps(mask, src + size - remainder);
+        __m256i output_vector = _mm512_cvtps_ph(input_vector, rounding_mode);
+        _mm256_mask_storeu_epi16(dst + size - remainder, mask, output_vector);
+    }
+}
+
 inline void float16_t::cvt_float16_to_float(const float16_t *src, float *dst, int size) {
     // Process 16 floats (AVX512 is a 512-bit SIMD register)
     constexpr int kStep = 16;
@@ -170,6 +202,27 @@ inline void float16_t::cvt_float16_to_float(const float16_t *src, float *dst, int size) {
     }
 }

+inline void float16_t::cvt_float16_to_float_MT(const float16_t *src, float *dst, int size) {
+    // Process 16 floats (AVX512 is a 512-bit SIMD register)
+    constexpr int kStep = 16;
+    int blockSize = size / kStep;
+    int remainder = size % kStep;
+
+#pragma omp parallel for
+    for (int i = 0; i < blockSize; ++i) {
+        __m256i input_vector = _mm256_maskz_loadu_epi16(0xffff, src + i * kStep);
+        __m512 output_vector = _mm512_cvtph_ps(input_vector);
+        _mm512_storeu_ps(dst + i * kStep, output_vector);
+    }
+
+    if (remainder != 0) {
+        __mmask16 mask = 0xFFFF >> (kStep - remainder);
+        __m256i input_vector = _mm256_maskz_loadu_epi16(mask, src + size - remainder);
+        __m512 output_vector = _mm512_cvtph_ps(input_vector);
+        _mm512_mask_storeu_ps(dst + size - remainder, mask, output_vector);
+    }
+}
+
 inline void float16_t::float_add_float16(const float *src1, const float16_t *src2, float *dst, int size) {
     constexpr int kStep = 16;
     int blockSize = size / kStep;
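
The new _MT variants keep the same blocked AVX512 conversion but parallelize the blocks with OpenMP and handle the tail with a masked load/store, which is what speeds up converting large weight tensors at load time. A small usage sketch using only the functions declared above (buffer size and contents are arbitrary):

#include <vector>
#include "float16.h"

int main() {
    const int n = 1000003; // deliberately not a multiple of 16, to exercise the masked tail
    std::vector<float> src(n, 1.5f);
    std::vector<float16_t> half(n);
    std::vector<float> back(n);

    float16_t::cvt_float_to_float16_MT(src.data(), half.data(), n);  // fp32 -> fp16, multi-threaded
    float16_t::cvt_float16_to_float_MT(half.data(), back.data(), n); // fp16 -> fp32, multi-threaded
    return 0;
}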

src/layers/attention.h

Lines changed: 2 additions & 7 deletions
@@ -812,13 +812,8 @@ class Attention {
                 auto srcV = value + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
                 auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i);

-                if constexpr (std::is_same_v<KVCacheT, float>) {
-                    memcpy(dstK, srcK, headSize * sizeof(float));
-                    memcpy(dstV, srcV, headSize * sizeof(float));
-                } else if constexpr (std::is_same_v<KVCacheT, float16_t>) {
-                    float16_t::cvt_float_to_float16(srcK, dstK, headSize);
-                    float16_t::cvt_float_to_float16(srcV, dstV, headSize);
-                }
+                xft::copy(dstK, srcK, headSize);
+                xft::copy(dstV, srcV, headSize);
             }
         }
     }
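
The per-type branch is collapsed into `xft::copy`, which dispatches on the KV-cache element type. Its definition is not part of this diff; a minimal sketch of what such an overload set could look like, purely as an assumption:

#include <cstring>
#include "float16.h"

namespace xft {

// float -> float: plain memcpy
inline void copy(float *dst, const float *src, int size) {
    memcpy(dst, src, size * sizeof(float));
}

// float -> float16_t: copy with down-conversion
inline void copy(float16_t *dst, const float *src, int size) {
    float16_t::cvt_float_to_float16(src, dst, size);
}

} // namespace xft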

src/layers/dist_linear.h

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ class DistLinear {
     //      |                                         |
     //      |                                         | splitSize(N)
     //      |_________________________________________|
-    void setWeight(DecoderContext *ctx, const float *w, const float *b) {
+    void setWeight(DecoderContext *ctx, const float *w, const float *b = nullptr) {
         this->splitSize = outputSize / splits;
         this->splitOffset = this->splitSize * splitIdx;

@@ -111,5 +111,5 @@ class DistLinear {
     hpj::Vector<float> scaleWeight; // if weight is int8
     hpj::Vector<float> zeroWeight;  // if weight is int8
     hpj::Vector<float> sumWeight;   // if weight is int8
-    float *bias;
+    float *bias = nullptr;
 };

src/layers/layer_norm.cpp

Lines changed: 17 additions & 8 deletions
@@ -17,35 +17,44 @@
 #include <cstdlib>
 #include <cstring>

-#include "layernorm_kernels.h"
 #include "layer_norm.h"
+#include "layernorm_kernels.h"
 #include "timeline.h"

 namespace xft {

 // Layer normalization: only support the norm along last dimension
 LayerNorm::LayerNorm() {
-    weights = nullptr;
+    gamma = nullptr;
+    beta = nullptr;
     normSize = 0;
 }

 LayerNorm::~LayerNorm() {
-    if (weights) { free(weights); }
+    if (gamma) { free(gamma); }
+    if (beta) { free(beta); }
 }

 void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) {
     this->normSize = cols;
-    this->weights = (float *)aligned_alloc(64, 2 * cols * sizeof(float));
-    memcpy(weights, gamma, cols * sizeof(float));
-    memcpy(weights + cols, beta, cols * sizeof(float));
+    this->gamma = (float *)aligned_alloc(64, cols * sizeof(float));
+    this->beta = (float *)aligned_alloc(64, cols * sizeof(float));
+    memcpy(this->gamma, gamma, cols * sizeof(float));
+    memcpy(this->beta, beta, cols * sizeof(float));
+}
+
+void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaPath, int cols) {
+    this->normSize = cols;
+    loadWeight(gammaPath, this->gamma, cols);
+    if (betaPath != "") loadWeight(betaPath, this->beta, cols);
 }

 // input and output are in shape of (rows, normSize)
 // TODO: column-wise parallel
 void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("LayerNorm.forward");
-    const float *pgamma = weights;
-    const float *pbeta = weights + normSize;
+    const float *pgamma = gamma;
+    const float *pbeta = beta;
     invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride);
 }
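
With the path-based overload, the model loader can hand weight file names straight to the layer instead of reading the tensors into temporary buffers first. A usage sketch; the file names are hypothetical:

#include <string>
#include "layer_norm.h"

void setupFinalNorm(xft::LayerNorm &norm, const std::string &dir, int hiddenSize) {
    // loadWeight() fills the layer's internal gamma/beta buffers with hiddenSize floats each.
    norm.setWeight(dir + "/model.final_layernorm.weight.bin",
            dir + "/model.final_layernorm.bias.bin", hiddenSize);
}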

src/layers/layer_norm.h

Lines changed: 6 additions & 2 deletions
@@ -14,6 +14,9 @@
 // ============================================================================
 #pragma once

+#include <string>
+#include "weight_util.h"
+
 namespace xft {

 // Layer normalization: only support the norm along last dimension
@@ -23,6 +26,7 @@ class LayerNorm {
     ~LayerNorm();

     void setWeight(const float *gamma, const float *beta, int cols);
+    void setWeight(const std::string &gammaPath, const std::string &betaPath, int cols);

     // input and output are in shape of (rows, normSize)
     // TODO: column-wise parallel
@@ -31,8 +35,8 @@ class LayerNorm {
 private:
     int normSize;

-    // the weights contains gamma and beta concated together
-    float *weights;
+    float *gamma = nullptr;
+    float *beta = nullptr;
 };

 } // namespace xft

src/layers/rms_norm.cpp

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,11 @@ void RmsNorm::setWeight(const float *w, const float *, int cols) {
     memcpy(weight, w, cols * sizeof(float));
 }

+void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int cols) {
+    this->normSize = cols;
+    loadWeight(modelPath, weight, cols);
+}
+
 // input and output are in shape of (rows, normSize)
 void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");
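
Both norm layers now delegate to `loadWeight()` from weight_util.h, whose implementation is not shown in this diff. A rough sketch of the assumed behavior (allocate the destination if needed, then read raw values from a binary file); the name and details here are guesses, not the actual utility:

#include <cstdio>
#include <cstdlib>
#include <string>

// Hypothetical stand-in for the real loadWeight() in weight_util.h.
template <typename T>
bool loadWeightSketch(const std::string &path, T *&ptr, int size) {
    if (ptr == nullptr) ptr = (T *)aligned_alloc(64, size * sizeof(T));
    FILE *fp = fopen(path.c_str(), "rb");
    if (fp == nullptr) return false;
    size_t n = fread(ptr, sizeof(T), size, fp);
    fclose(fp);
    return n == (size_t)size;
}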

src/layers/rms_norm.h

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,7 @@
 #pragma once

 #include "bfloat16.h"
+#include "weight_util.h"

 namespace xft {

@@ -25,6 +26,7 @@ class RmsNorm {
     ~RmsNorm();

     void setWeight(const float *w, const float *, int cols);
+    void setWeight(const std::string &modelPath, const std::string &, int cols);

     // Input and output are in shape of (rows, normSize)
     void forward(const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6);
@@ -41,7 +43,7 @@ class RmsNorm {
     int normSize;

     // the scale weight
-    float *weight;
+    float *weight = nullptr;
 };

 } // namespace xft

src/layers/token_embedding.h

Lines changed: 5 additions & 1 deletion
@@ -38,6 +38,10 @@ class TokenEmbedding {
         }
     }

+    void setWeights(const std::string &weightPath) {
+        loadWeight(weightPath, embTable, vocabSize * hiddenSize);
+    }
+
     // tokenIds is a 2-dimensional array with batchSize rows, and seqLen cols
     template <typename OutT>
     void forward(int *tokenIds, OutT *output, int batchSize, int seqLen) {
@@ -57,5 +61,5 @@ class TokenEmbedding {
     int vocabSize;
     int hiddenSize;

-    T *embTable;
+    T *embTable = nullptr;
 };
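
setWeights() gives the embedding table the same load-from-file path as the norm layers. A usage sketch, assuming TokenEmbedding is templated on its storage type (as the `T *embTable` member suggests); the file name is hypothetical:

#include <string>
#include "token_embedding.h"

template <typename T>
void loadEmbedding(TokenEmbedding<T> &emb, const std::string &dir) {
    // Reads vocabSize * hiddenSize values into embTable via loadWeight().
    emb.setWeights(dir + "/model.wte.bin");
}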
