|
17 | 17 | #include <cstdlib> |
18 | 18 | #include <cstring> |
19 | 19 |
|
20 | | -#include "layernorm_kernels.h" |
21 | 20 | #include "layer_norm.h" |
| 21 | +#include "layernorm_kernels.h" |
22 | 22 | #include "timeline.h" |
23 | 23 |
|
24 | 24 | namespace xft { |
25 | 25 |
|
26 | 26 | // Layer normalization: only support the norm along last dimension |
27 | 27 | LayerNorm::LayerNorm() { |
28 | | - weights = nullptr; |
| 28 | + gamma = nullptr; |
| 29 | + beta = nullptr; |
29 | 30 | normSize = 0; |
30 | 31 | } |
31 | 32 |
|
32 | 33 | LayerNorm::~LayerNorm() { |
33 | | - if (weights) { free(weights); } |
| 34 | + if (gamma) { free(gamma); } |
| 35 | + if (beta) { free(beta); } |
34 | 36 | } |
35 | 37 |
|
36 | 38 | void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) { |
37 | 39 | this->normSize = cols; |
38 | | - this->weights = (float *)aligned_alloc(64, 2 * cols * sizeof(float)); |
39 | | - memcpy(weights, gamma, cols * sizeof(float)); |
40 | | - memcpy(weights + cols, beta, cols * sizeof(float)); |
| 40 | + this->gamma = (float *)aligned_alloc(64, cols * sizeof(float)); |
| 41 | + this->beta = (float *)aligned_alloc(64, cols * sizeof(float)); |
| 42 | + memcpy(this->gamma, gamma, cols * sizeof(float)); |
| 43 | + memcpy(this->beta, beta, cols * sizeof(float)); |
| 44 | +} |
| 45 | + |
| 46 | +void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaPath, int cols) { |
| 47 | + this->normSize = cols; |
| 48 | + loadWeight(gammaPath, this->gamma, cols); |
| 49 | + if (betaPath != "") loadWeight(betaPath, this->beta, cols); |
41 | 50 | } |
42 | 51 |
|
43 | 52 | // input and output are in shape of (rows, normSize) |
44 | 53 | // TODO: column-wise parallel |
45 | 54 | void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { |
46 | 55 | TimeLine t("LayerNorm.forward"); |
47 | | - const float *pgamma = weights; |
48 | | - const float *pbeta = weights + normSize; |
| 56 | + const float *pgamma = gamma; |
| 57 | + const float *pbeta = beta; |
49 | 58 | invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride); |
50 | 59 | } |
51 | 60 |
|
|
0 commit comments