40 changes: 23 additions & 17 deletions src/layers/attention.h
@@ -70,7 +70,7 @@ class Attention {
     }
 
     // The inerface is for PyTorch, thus the weights are already transposed
-    // OriWeiT: float or int8_t
+    // OriWeiT: float, int8_t or uint4x2_t
     template <typename OriWeiT>
     void setWeights(DecoderContext *ctx, const OriWeiT *queryWeight, const float *queryScale, const float *queryZero,
             const float *queryBias, const OriWeiT *keyWeight, const float *keyScale, const float *keyZero,
@@ -87,31 +87,37 @@
         int responsibleCols = qResponsibleCols + 2 * kvResponsibleCols;
         qkvWeight.Resize(hiddenSize, responsibleCols);
 
-        OriWeiT *concatBuf = (OriWeiT *)malloc(hiddenSize * responsibleCols * sizeof(OriWeiT));
+        constexpr int sizeFactor = std::is_same_v<OriWeiT, uint4x2_t> ? 2 : 1;
+
+        OriWeiT *concatBuf = (OriWeiT *)malloc(hiddenSize * responsibleCols * sizeof(OriWeiT) / sizeFactor);
         if (trans) {
-            memcpy(concatBuf, queryWeight + this->startQHead * headSize * hiddenSize,
-                    hiddenSize * qResponsibleCols * sizeof(OriWeiT));
-            memcpy(concatBuf + hiddenSize * qResponsibleCols, keyWeight + this->startKVHead * headSize * hiddenSize,
-                    hiddenSize * kvResponsibleCols * sizeof(OriWeiT));
-            memcpy(concatBuf + hiddenSize * (qResponsibleCols + kvResponsibleCols),
-                    valueWeight + this->startKVHead * headSize * hiddenSize,
-                    hiddenSize * kvResponsibleCols * sizeof(OriWeiT));
+            memcpy(concatBuf, queryWeight + this->startQHead * headSize * hiddenSize / sizeFactor,
+                    hiddenSize * qResponsibleCols * sizeof(OriWeiT) / sizeFactor);
+            memcpy(concatBuf + hiddenSize * qResponsibleCols / sizeFactor,
+                    keyWeight + this->startKVHead * headSize * hiddenSize / sizeFactor,
+                    hiddenSize * kvResponsibleCols * sizeof(OriWeiT) / sizeFactor);
+            memcpy(concatBuf + hiddenSize * (qResponsibleCols + kvResponsibleCols) / sizeFactor,
+                    valueWeight + this->startKVHead * headSize * hiddenSize / sizeFactor,
+                    hiddenSize * kvResponsibleCols * sizeof(OriWeiT) / sizeFactor);
         } else {
             int qkvStride = (ctx->attHeadNum + ctx->kvHeadNum + ctx->kvHeadNum) * ctx->attHeadSize;
 #pragma omp parallel for
             for (int i = 0; i < hiddenSize; ++i) {
-                memcpy(concatBuf + i * responsibleCols, queryWeight + i * qkvStride + this->startQHead * headSize,
-                        qResponsibleCols * sizeof(OriWeiT));
-                memcpy(concatBuf + i * responsibleCols + qResponsibleCols,
-                        keyWeight + i * qkvStride + this->startKVHead * headSize, kvResponsibleCols * sizeof(OriWeiT));
-                memcpy(concatBuf + i * responsibleCols + qResponsibleCols + kvResponsibleCols,
-                        valueWeight + i * qkvStride + this->startKVHead * headSize,
-                        kvResponsibleCols * sizeof(OriWeiT));
+                memcpy(concatBuf + i * responsibleCols / sizeFactor,
+                        queryWeight + i * qkvStride / sizeFactor + this->startQHead * headSize / sizeFactor,
+                        qResponsibleCols * sizeof(OriWeiT) / sizeFactor);
+                memcpy(concatBuf + i * responsibleCols / sizeFactor + qResponsibleCols / sizeFactor,
+                        keyWeight + i * qkvStride / sizeFactor + this->startKVHead * headSize / sizeFactor,
+                        kvResponsibleCols * sizeof(OriWeiT) / sizeFactor);
+                memcpy(concatBuf + i * responsibleCols / sizeFactor + qResponsibleCols / sizeFactor
+                                + kvResponsibleCols / sizeFactor,
+                        valueWeight + i * qkvStride / sizeFactor + this->startKVHead * headSize / sizeFactor,
+                        kvResponsibleCols * sizeof(OriWeiT) / sizeFactor);
             }
         }
         float *concatScale = nullptr;
         float *concatZero = nullptr;
-        if constexpr (std::is_same_v<OriWeiT, int8_t>) {
+        if constexpr (std::is_same_v<OriWeiT, int8_t> || std::is_same_v<OriWeiT, uint4x2_t>) {
             concatScale = (float *)malloc(responsibleCols * sizeof(float));
             concatZero = (float *)malloc(responsibleCols * sizeof(float));
             memcpy(concatScale, queryScale + this->startQHead * headSize, qResponsibleCols * sizeof(float));
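Note on the arithmetic above: a uint4x2_t element stores two 4-bit weights, so when OriWeiT is uint4x2_t every byte count and pointer offset computed in logical columns shrinks by sizeFactor = 2. A minimal standalone sketch of that rule, assuming uint4x2_t packs two consecutive 4-bit values per byte and that all counts and offsets involved are even (the names below are illustrative, not the xFT API):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

struct uint4x2_t { uint8_t v; }; // assumed layout: two 4-bit weights per byte

// Bytes needed to store `n` logical weights of type W (n assumed even for uint4x2_t).
template <typename W>
constexpr size_t storageBytes(size_t n) {
    constexpr size_t sizeFactor = std::is_same_v<W, uint4x2_t> ? 2 : 1;
    return n * sizeof(W) / sizeFactor;
}

// Copy `cols` logical weights starting at logical column `startCol` of row `row`;
// both the source offset (in W elements) and the byte count shrink by sizeFactor.
template <typename W>
void copyRowSlice(W *dst, const W *src, size_t row, size_t rowWidth, size_t startCol, size_t cols) {
    constexpr size_t sizeFactor = std::is_same_v<W, uint4x2_t> ? 2 : 1;
    std::memcpy(dst, src + (row * rowWidth + startCol) / sizeFactor, storageBytes<W>(cols));
}

The same factor accounts for the halved malloc size and the divided offsets in both the trans and non-trans branches above, while the per-column scales and zeros keep their full-width offsets.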
2 changes: 1 addition & 1 deletion src/layers/decoder_layer.h
@@ -73,7 +73,7 @@ class Decoder {
 
     int getLayerId() { return layerIdx; }
 
-    // OriWeiT: float or int8_t
+    // OriWeiT: float, int8_t or uint4x2_t
    template <typename OriWeiT>
     void setWeights(DecoderContext *ctx, const OriWeiT *queryWeight, const float *queryScale, const float *queryZero,
             const float *queryBias, const OriWeiT *keyWeight, const float *keyScale, const float *keyZero,
2 changes: 1 addition & 1 deletion src/layers/mlp_llama.h
@@ -40,7 +40,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
 
     LlamaMLP(DecoderContext *ctx) {}
 
-    // OriWeiT: float or int8_t
+    // OriWeiT: float, int8_t or uint4x2_t
     template <typename OriWeiT>
     void setWeights(DecoderContext *ctx, const OriWeiT *gateW, const float *gateS, const float *gateZ,
             const float * /*unused*/, const OriWeiT *upW, const float *upS, const float *upZ, const float * /*unused*/,
40 changes: 23 additions & 17 deletions src/models/common_decoder.h
@@ -208,8 +208,8 @@ class CommonDecoder : public AbstractDecoder {
 
         // DataType dt = getWeightType(configPath, modelType);
         DataType dt = DataType::fp32;
-        if (quantQweightDataType == "int8") {
-            dt = DataType::int8;
+        if (quantQweightDataType == "int8" || quantQweightDataType == "uint4") {
+            dt = quantQweightDataType == "int8" ? DataType::int8 : DataType::int4;
             REQUIRES(quantScalesDataType == "fp32", "scales should be fp32 data type.");
             REQUIRES(quantZerosDataType == "fp32", "zeros should be fp32 data type.");
             REQUIRES(quantGroupsize == -1, "Quantization with groupsize is not supported.");
@@ -240,6 +240,8 @@ class CommonDecoder : public AbstractDecoder {
             auto pdec = new DECODER(ctx, i);
             if (dt == DataType::int8) {
                 this->setDecoderWeights<int8_t>(pdec, modelPath, i);
+            } else if (dt == DataType::int4) {
+                this->setDecoderWeights<uint4x2_t>(pdec, modelPath, i);
             } else if (dt == DataType::fp32) {
                 this->setDecoderWeights<float>(pdec, modelPath, i);
             }
@@ -628,7 +630,7 @@ class CommonDecoder : public AbstractDecoder {
         return this->context.get();
     }
 
-    // OriWeiT: float or int8_t
+    // OriWeiT: float, int8_t or uint4x2_t
     template <typename OriWeiT>
     void setDecoderWeights(DECODER *pdecoder, const std::string &modelPath, int layerIdx) {
         const int hiddenSize = getContext()->hiddenSize;
@@ -670,8 +672,10 @@ class CommonDecoder : public AbstractDecoder {
         float *fc3Scales = nullptr;
         float *fc3Zeros = nullptr;
 
-        // INT8 quant, wbits = 8, qweight dtype: int8
-        if constexpr (std::is_same_v<OriWeiT, int8_t>) {
+        // INT8/INT4 quant, wbits = 8/4, qweight dtype: int8_t/uint4x2_t
+        if constexpr (std::is_same_v<OriWeiT, int8_t> || std::is_same_v<OriWeiT, uint4x2_t>) {
+            DataType dt = std::is_same_v<OriWeiT, int8_t> ? DataType::int8 : DataType::int4;
+
             qkvZeros = (float *)ALLOC(qkvSize * sizeof(float), 64);
             qkvScales = (float *)ALLOC(qkvSize * sizeof(float), 64);
             attnOutZeros = (float *)ALLOC(hiddenSize * sizeof(float), 64);
@@ -683,7 +687,7 @@
 
             loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx)
                             + ".attention.query_key_value.qweight.0.bin",
-                    qkvWeight, hiddenSize * qkvSize, DataType::int8);
+                    qkvWeight, hiddenSize * qkvSize, dt);
             loadWeight(
                     modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.query_key_value.zeros.0.bin",
                     qkvZeros, qkvSize, DataType::fp32);
@@ -692,7 +696,7 @@
                     qkvScales, qkvSize, DataType::fp32);
 
             loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.qweight.0.bin",
-                    attnOutWeight, hiddenSize * hiddenSize, DataType::int8);
+                    attnOutWeight, hiddenSize * hiddenSize, dt);
             loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.zeros.0.bin",
                     attnOutZeros, hiddenSize, DataType::fp32);
             loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".attention.dense.scales.0.bin",
@@ -702,14 +706,14 @@
             if (fileExists(
                         modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin")) {
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.qweight.0.bin",
-                        fc1Weight, hiddenSize * imSize * mlpFactor, DataType::int8);
+                        fc1Weight, hiddenSize * imSize * mlpFactor, dt);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.zeros.0.bin",
                         fc1Zeros, imSize * mlpFactor, DataType::fp32);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_h_to_4h.scales.0.bin",
                         fc1Scales, imSize * mlpFactor, DataType::fp32);
 
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.qweight.0.bin",
-                        fc2Weight, hiddenSize * imSize, DataType::int8);
+                        fc2Weight, hiddenSize * imSize, dt);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.zeros.0.bin",
                         fc2Zeros, hiddenSize, DataType::fp32);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.scales.0.bin",
@@ -722,21 +726,21 @@
                 fc3Scales = (float *)ALLOC(hiddenSize * sizeof(float), 64);
 
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.qweight.0.bin",
-                        fc1Weight, hiddenSize * imSize * mlpFactor, DataType::int8);
+                        fc1Weight, hiddenSize * imSize * mlpFactor, dt);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.zeros.0.bin",
                         fc1Zeros, imSize * mlpFactor, DataType::fp32);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.gate_proj.scales.0.bin",
                         fc1Scales, imSize * mlpFactor, DataType::fp32);
 
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.qweight.0.bin",
-                        fc2Weight, hiddenSize * imSize, DataType::int8);
+                        fc2Weight, hiddenSize * imSize, dt);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.zeros.0.bin",
                         fc2Zeros, imSize, DataType::fp32);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.up_proj.scales.0.bin",
                         fc2Scales, imSize, DataType::fp32);
 
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.qweight.0.bin",
-                        fc3Weight, hiddenSize * imSize, DataType::int8);
+                        fc3Weight, hiddenSize * imSize, dt);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.zeros.0.bin",
                         fc3Zeros, hiddenSize, DataType::fp32);
                 loadWeight(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.down_proj.scales.0.bin",
@@ -803,11 +807,13 @@ class CommonDecoder : public AbstractDecoder {
         READ_OPTIONAL(modelPath + "/model.layers." + std::to_string(layerIdx) + ".mlp.dense_4h_to_h.bias.bin", fc2Bias,
                 hiddenSize, "read FC2 bias error");
 
-        pdecoder->setWeights(getContext(), qkvWeight, qkvScales, qkvZeros, qkvBias, qkvWeight + qSize,
-                qkvScales + qSize, qkvZeros + qSize, qkvBias + qSize, qkvWeight + qSize + kvSize,
-                qkvScales + qSize + kvSize, qkvZeros + qSize + kvSize, qkvBias + qSize + kvSize, attnOutWeight,
-                attnOutScales, attnOutZeros, attnOutBias, ln1Gamma, ln1Beta, fc1Weight, fc1Scales, fc1Zeros, fc1Bias,
-                fc2Weight, fc2Scales, fc2Zeros, fc2Bias, ln2Gamma, ln2Beta, fc3Weight, fc3Scales, fc3Zeros, false);
+        constexpr int sizeFactor = std::is_same_v<OriWeiT, uint4x2_t> ? 2 : 1;
+        pdecoder->setWeights(getContext(), qkvWeight, qkvScales, qkvZeros, qkvBias, qkvWeight + qSize / sizeFactor,
+                qkvScales + qSize, qkvZeros + qSize, qkvBias + qSize,
+                qkvWeight + qSize / sizeFactor + kvSize / sizeFactor, qkvScales + qSize + kvSize,
+                qkvZeros + qSize + kvSize, qkvBias + qSize + kvSize, attnOutWeight, attnOutScales, attnOutZeros,
+                attnOutBias, ln1Gamma, ln1Beta, fc1Weight, fc1Scales, fc1Zeros, fc1Bias, fc2Weight, fc2Scales, fc2Zeros,
+                fc2Bias, ln2Gamma, ln2Beta, fc3Weight, fc3Scales, fc3Zeros, false);
 
         free(qkvWeight);
         free(attnOutWeight);
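One detail in the setWeights call above: only the packed qkvWeight pointer is advanced in storage units (qSize / sizeFactor), while the scale, zero and bias pointers keep their logical offsets, since they stay one float per output column regardless of weight packing. A hedged sketch of that split, with illustrative names and an assumed even qSize/kvSize, not the actual xFT signature:

#include <cstdint>
#include <type_traits>

struct uint4x2_t { uint8_t v; };

template <typename OriWeiT>
struct QkvViews {
    const OriWeiT *q, *k, *v;                  // packed weight pointers (storage units)
    const float *qScales, *kScales, *vScales;  // one float per logical column
};

// qSize/kvSize are logical element counts inside the fused QKV buffer.
template <typename OriWeiT>
QkvViews<OriWeiT> splitQkv(const OriWeiT *qkvWeight, const float *qkvScales, int qSize, int kvSize) {
    constexpr int sizeFactor = std::is_same_v<OriWeiT, uint4x2_t> ? 2 : 1;
    return {qkvWeight,
            qkvWeight + qSize / sizeFactor,
            qkvWeight + qSize / sizeFactor + kvSize / sizeFactor,
            qkvScales,
            qkvScales + qSize,
            qkvScales + qSize + kvSize};
}

For float and int8_t weights sizeFactor is 1, so the pointer arithmetic collapses to the original call.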
16 changes: 16 additions & 0 deletions src/utils/matmul_helper.h
@@ -224,6 +224,22 @@ class MMHelper {
             }
         }
 
+        // UINT4 -> UINT4
+        else if constexpr (std::is_same_v<OriWeiT, uint4x2_t> && std::is_same_v<WeiT, uint4x2_t>) {
+            int size = trans ? rowSize : colSize;
+            int offset = trans ? rowOffset : colOffset;
+            scaleWeight.Resize(size);
+            zeroWeight.Resize(size);
+            memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float));
+            memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float));
+#pragma omp parallel for
+            for (uint64_t i = 0; i < rowSize; i++) {
+                WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride() / 2;
+                const OriWeiT *src = weight + (rowOffset + i) * cols / 2 + colOffset / 2;
+                memcpy(dst, src, colSize * sizeof(WeiT) / 2);
+            }
+        }
+
         else {
             printf("%s:%d: Do not support this kind of weights datatype convertion.\n", __FILE__, __LINE__);
             exit(-1);
4 changes: 4 additions & 0 deletions src/utils/weight_util.h
@@ -161,13 +161,16 @@ int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataTy
         std::filesystem::path folderPath = pathObj.parent_path();
         w_type = getWeightType(folderPath.append("config.ini").string());
     }
+    //1 uint4x2 stores 2 uint4 value, so load size is halfed.
+    if constexpr (std::is_same_v<T, uint4x2_t>) { size = size / 2; }
     if (!ptr) { ptr = (T *)xft::alloc(size * sizeof(T)); }
     int file_size = 0;
     switch (w_type) {
         case DataType::fp32: file_size = loadWeightWithConvert<T, float>(ptr, size, filename, required); break;
         case DataType::fp16: file_size = loadWeightWithConvert<T, float16_t>(ptr, size, filename, required); break;
         case DataType::bf16: file_size = loadWeightWithConvert<T, bfloat16_t>(ptr, size, filename, required); break;
         case DataType::int8: file_size = loadWeightWithConvert<T, int8_t>(ptr, size, filename, required); break;
+        case DataType::int4: file_size = loadWeightWithConvert<T, uint4x2_t>(ptr, size, filename, required); break;
         default: printf("Not support loading %s with DataType=%d", filename.c_str(), w_type);
     }
     return file_size;
@@ -188,4 +191,5 @@ template int loadWeightWithConvert<uint4x2_t, float16_t>(uint4x2_t *, int, const
 template int loadWeightWithConvert<nf4x2_t, float16_t>(nf4x2_t *, int, const std::string &, bool);
 
 template int loadWeightWithConvert<int8_t, int8_t>(int8_t *, int, const std::string &, bool);
+template int loadWeightWithConvert<uint4x2_t, uint4x2_t>(uint4x2_t *, int, const std::string &, bool);
 } // namespace xft
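To make the size halving in loadWeight concrete: callers keep passing the logical number of 4-bit weights, and the loader converts that count to storage elements before allocating and reading. A minimal sketch under that assumption (it is not the actual xFT loader, which also routes through loadWeightWithConvert for dtype conversion):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <type_traits>

struct uint4x2_t { uint8_t v; };

// `size` is the logical weight count; for uint4x2_t one element stores two values.
template <typename T>
int loadRawWeight(const char *filename, T *&ptr, int size) {
    if constexpr (std::is_same_v<T, uint4x2_t>) { size /= 2; }
    if (!ptr) { ptr = (T *)malloc(size * sizeof(T)); }
    FILE *fp = fopen(filename, "rb");
    if (!fp) { return 0; }
    int readElems = (int)fread(ptr, sizeof(T), size, fp);
    fclose(fp);
    return readElems; // storage elements actually read
}

// Usage sketch: hiddenSize * qkvSize logical 4-bit weights occupy half as many bytes on disk.
// uint4x2_t *qkvWeight = nullptr;
// loadRawWeight("model.layers.0.attention.query_key_value.qweight.0.bin", qkvWeight, hiddenSize * qkvSize);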