/*
 * int8_gemm.h
 * INT8 GEMM methods adapted from FasterTransformer (FT).
 */
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <mutex>

#include "cublasAlgoMap.h"
#include "cublasINT8MMWrapper.h"
class I8CUGEMM {
 private:
  cublasINT8MMWrapper* int8_gemm_wrapper = nullptr;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

 public:
  I8CUGEMM();
  ~I8CUGEMM();

  // output (INT32, [m, n]) = input (INT8, [m, k]) x weight^T (INT8, [n, k]).
  void linear_a8_w8_o32(torch::Tensor& input, torch::Tensor& weight,
                        torch::Tensor& output);
  // Same contract, dispatched to the wrapper's Gemm_ variant.
  void linear_a8_w8_o32_(torch::Tensor& input, torch::Tensor& weight,
                         torch::Tensor& output);
  // output (INT8, [m, n]) = alpha * input x weight^T, stored as INT8.
  void linear_a8_w8_o8(torch::Tensor& input, torch::Tensor& weight,
                       torch::Tensor& output, float alpha);
  // Same contract, dispatched to the wrapper's Gemm_ variant.
  void linear_a8_w8_o8_(torch::Tensor& input, torch::Tensor& weight,
                        torch::Tensor& output, float alpha);
};
I8CUGEMM::I8CUGEMM() {
  // cublasAlgoMap* cublas_algo_map = new cublasAlgoMap("igemm_config.in");
  cublasAlgoMap* cublas_algo_map = new cublasAlgoMap();
  std::mutex* cublas_wrapper_mutex = new std::mutex();
  // Weight-layout flag passed to the FT wrapper (interleaved COL32_2R_4R4).
  bool use_ORDER_COL32_2R_4R4 = true;

  cublasLtHandle_t cublaslt_handle;
  cublasLtCreate(&cublaslt_handle);

  int8_gemm_wrapper =
      new cublasINT8MMWrapper(cublaslt_handle, this->stream, cublas_algo_map,
                              cublas_wrapper_mutex, use_ORDER_COL32_2R_4R4);
}
// Note: the wrapper, algo map, mutex, and cublasLt handle allocated in the
// constructor are not released here.
I8CUGEMM::~I8CUGEMM() {}
void I8CUGEMM::linear_a8_w8_o32(torch::Tensor& input,   // INT8, [m, k]
                                torch::Tensor& weight,  // INT8, [n, k]
                                torch::Tensor& out      // INT32, [m, n]
) {
  int m = input.size(0);
  int n = weight.size(0);
  int k = input.size(1);

  // Raw device pointers for the cuBLASLt wrapper.
  int8_t* input_ptr = input.data_ptr<int8_t>();
  int8_t* weight_ptr = weight.data_ptr<int8_t>();
  int32_t* output_ptr = out.data_ptr<int32_t>();

  int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr,
                          weight_ptr);
}
void I8CUGEMM::linear_a8_w8_o32_(torch::Tensor& input,   // INT8, [m, k]
                                 torch::Tensor& weight,  // INT8, [n, k]
                                 torch::Tensor& out      // INT32, [m, n]
) {
  int m = input.size(0);
  int n = weight.size(0);
  int k = input.size(1);

  // Raw device pointers for the cuBLASLt wrapper.
  int8_t* input_ptr = input.data_ptr<int8_t>();
  int8_t* weight_ptr = weight.data_ptr<int8_t>();
  int32_t* output_ptr = out.data_ptr<int32_t>();

  int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr,
                           weight_ptr);
}
void I8CUGEMM::linear_a8_w8_o8(torch::Tensor& input,   // INT8, [m, k]
                               torch::Tensor& weight,  // INT8, [n, k]
                               torch::Tensor& out,     // INT8, [m, n]
                               float alpha             // FP32 output scale
) {
  int m = input.size(0);
  int n = weight.size(0);
  int k = input.size(1);

  // Raw device pointers for the cuBLASLt wrapper.
  int8_t* input_ptr = input.data_ptr<int8_t>();
  int8_t* weight_ptr = weight.data_ptr<int8_t>();
  int8_t* output_ptr = out.data_ptr<int8_t>();

  int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr,
                          weight_ptr);
}
void I8CUGEMM::linear_a8_w8_o8_(torch::Tensor& input,   // INT8, [m, k]
                                torch::Tensor& weight,  // INT8, [n, k]
                                torch::Tensor& out,     // INT8, [m, n]
                                float alpha             // FP32 output scale
) {
  int m = input.size(0);
  int n = weight.size(0);
  int k = input.size(1);

  // Raw device pointers for the cuBLASLt wrapper.
  int8_t* input_ptr = input.data_ptr<int8_t>();
  int8_t* weight_ptr = weight.data_ptr<int8_t>();
  int8_t* output_ptr = out.data_ptr<int8_t>();

  int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr,
                           weight_ptr);
}