Quantization AWQ GEMM + GEMV #1727

Merged 6 commits on Jul 4, 2024
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -191,6 +191,13 @@ set(SOURCES
src/ops/transpose.cc
src/ops/nccl_ops.cc
src/ops/nccl_ops_cpu.cc
src/ops/awq/dequantize.cc
src/ops/awq/dequantize_cpu.cc
src/ops/awq/gemm.cc
src/ops/awq/gemm_cpu.cc
src/ops/awq/gemv.cc
src/ops/awq/gemv_cpu.cc
src/ops/sum.cc
src/padder.cc
src/profiler.cc
src/random.cc
@@ -595,6 +602,9 @@ if (WITH_CUDA)
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
src/ops/awq/gemm_gpu.cu
src/ops/awq/gemv_gpu.cu
src/ops/awq/dequantize_gpu.cu
)

set_source_files_properties(
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ The project is production-oriented and comes with [backward compatibility guaran
## Key features

* **Fast and efficient execution on CPU and GPU**<br/>The execution [is significantly faster and requires less resources](#benchmarks) than general-purpose deep learning frameworks on supported models and tasks thanks to many advanced optimizations: layer fusion, padding removal, batch reordering, in-place operations, caching mechanism, etc.
* **Quantization and reduced precision**<br/>The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), and 8-bit integers (INT8).
* **Quantization and reduced precision**<br/>The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), 8-bit integers (INT8), and AWQ quantization (INT4).
* **Multiple CPU architectures support**<br/>The project supports x86-64 and AArch64/ARM64 processors and integrates multiple backends that are optimized for these platforms: [Intel MKL](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html), [oneDNN](https://github.com/oneapi-src/oneDNN), [OpenBLAS](https://www.openblas.net/), [Ruy](https://github.com/google/ruy), and [Apple Accelerate](https://developer.apple.com/documentation/accelerate).
* **Automatic CPU detection and code dispatch**<br/>One binary can include multiple backends (e.g. Intel MKL and oneDNN) and instruction set architectures (e.g. AVX, AVX2) that are automatically selected at runtime based on the CPU information.
* **Parallel and asynchronous execution**<br/>Multiple batches can be processed in parallel and asynchronously using multiple GPUs or CPU cores.
19 changes: 19 additions & 0 deletions docs/quantization.md
@@ -6,6 +6,7 @@ Quantization is a technique that can reduce the model size and accelerate its ex
* 16-bit integers (INT16)
* 16-bit floating points (FP16)
* 16-bit brain floating points (BF16)
* 4-bit AWQ quantization (INT4)

```{tip}
See the benchmark results in the main [README](https://github.com/OpenNMT/CTranslate2#benchmarks) to compare the performance and memory usage with and without quantization.
@@ -161,3 +162,21 @@ In this mode, all model weights are stored in half precision and all layers are
* NVIDIA GPU with Compute Capability >= 8.0

In this mode, all model weights are stored in BF16 and all layers are run with this type.

### 4-bit AWQ

The compute type is `int32_float16`.

**Supported on:**

* NVIDIA GPU with Compute Capability >= 7.5

In this mode, all model weights are stored in half precision and all layers are run in half precision, while the quantization parameters (scales and zero points) are stored in `int32`.

The model must first be quantized with AWQ and then converted to the CTranslate2 format, for example:

```bash
ct2-transformers-converter --model TheBloke/Llama-2-7B-AWQ --copy_files tokenizer.model --output_dir ct2_model
```
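
Once converted, the model can be loaded like any other CTranslate2 model. The snippet below is a minimal sketch, assuming the `ct2_model` directory produced by the command above, a CUDA device, and the original Hugging Face tokenizer; the prompt is purely illustrative.

```python
import ctranslate2
import transformers

# Load the converted 4-bit AWQ model (hypothetical local directory "ct2_model").
generator = ctranslate2.Generator("ct2_model", device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-AWQ")

# Tokenize a prompt into string tokens, generate greedily, and decode the result.
prompt = "Quantization reduces model size because"
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
results = generator.generate_batch([tokens], max_length=64, sampling_topk=1)
print(tokenizer.decode(results[0].sequences_ids[0]))
```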
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/common.h
@@ -138,16 +138,19 @@ namespace ctranslate2 {
const StorageView& _weight;
const StorageView* _bias;
const StorageView* _qscale;
const StorageView* _qzero;
const StorageView* _u8_shift_compensation;
StorageView _partial_weight;
StorageView _partial_bias;
StorageView _partial_qscale;
StorageView _partial_u8_shift_compensation;
const DataType _output_type;
const models::QUANTIZATION_TYPE _quant_method;
const bool _quantized_gemm;
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
const ops::ActivationType* _activation_type;
const bool _is_layer_out;
};

17 changes: 16 additions & 1 deletion include/ctranslate2/models/model.h
@@ -11,6 +11,12 @@
namespace ctranslate2 {
namespace models {

enum class QUANTIZATION_TYPE {
CT2,
AWQ_GEMM,
AWQ_GEMV
};

static const size_t current_binary_version = 6;

// Checks whether the provided path could contain a CTranslate2 model.
@@ -90,6 +96,14 @@ namespace ctranslate2 {
return _use_flash_attention;
}

QUANTIZATION_TYPE quant_method() const {
return _quant_method;
}

void set_quant_method(QUANTIZATION_TYPE type) {
_quant_method = type;
}

virtual bool use_global_int16_scale() const {
return true;
}
@@ -160,7 +174,7 @@ namespace ctranslate2 {

private:
void process_linear_weights();
void set_compute_type(ComputeType type, Device device, int device_index);
void set_compute_type(ComputeType type, Device device, int device_index, bool update_weight=true);
void ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype);
@@ -177,6 +191,7 @@ namespace ctranslate2 {
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
bool _use_flash_attention = false;
bool _tensor_parallel = false;
QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2;
};

template<>
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/awq/dequantize_awq.h
@@ -0,0 +1,26 @@
#pragma once

#include "../op.h"

namespace ctranslate2 {
namespace ops {

class DequantizeAwq : public Op {
public:
DequantizeAwq();

void operator()(const StorageView& input,
const StorageView& scale,
const StorageView& zeros,
StorageView& output) const;

private:
template <Device D, typename InT, typename OutT>
void dequantize(const StorageView& input,
const StorageView& scale,
const StorageView& zeros,
StorageView& output) const;
};

}
}
27 changes: 27 additions & 0 deletions include/ctranslate2/ops/awq/gemm.h
@@ -0,0 +1,27 @@
#pragma once

#include "../activation.h"
#include "../gemm.h"

namespace ctranslate2 {
namespace ops {
class GemmAwq : public Gemm {
public:
using Gemm::Gemm;
void operator()(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c,
const StorageView* bias = nullptr) const;

private:
template <Device D, typename In, typename Out>
void compute(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c) const;
};
}
}
33 changes: 33 additions & 0 deletions include/ctranslate2/ops/awq/gemv.h
@@ -0,0 +1,33 @@
#pragma once

#include "../activation.h"
#include "../gemm.h"

namespace ctranslate2 {
namespace ops {
class GemvAwq : public Gemm {
public:
using Gemm::Gemm;
void operator()(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c,
const StorageView* bias = nullptr) const;

private:
template <Device D, typename In, typename Out>
void compute_gemv(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c) const;
template <Device D, typename In, typename Out>
void compute_gemv2(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c) const;
};
}
}
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/gemm.h
@@ -39,6 +39,8 @@ namespace ctranslate2 {
const dim_t k,
const dim_t n,
const float alpha);
protected:
const ActivationType* _activation_type;

private:
float _alpha;
@@ -47,7 +49,6 @@ namespace ctranslate2 {
bool _trans_b;
bool _a_is_packed;
bool _b_is_packed;
const ActivationType* _activation_type;

template <Device D, typename In, typename Out>
void compute(const StorageView& a,
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/mean.h
@@ -11,12 +11,13 @@ namespace ctranslate2 {

void operator()(const StorageView& input, StorageView& output) const override;

private:
protected:
template <Device D, typename T>
void compute(const StorageView& input,
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
const bool get_sum,
StorageView& output) const;

const dim_t _axis;
4 changes: 4 additions & 0 deletions include/ctranslate2/ops/ops.h
@@ -39,3 +39,7 @@
#include "slide.h"
#include "nccl_ops.h"
#include "flash_attention.h"
#include "awq/gemm.h"
#include "awq/gemv.h"
#include "awq/dequantize_awq.h"
#include "sum.h"
17 changes: 17 additions & 0 deletions include/ctranslate2/ops/sum.h
@@ -0,0 +1,17 @@
#pragma once

#include "op.h"
#include "mean.h"

namespace ctranslate2 {
namespace ops {

class Sum : public Mean {
public:
Sum(const dim_t axis);

void operator()(const StorageView& input, StorageView& output) const override;
};

}
}