[Core] Support loading GGUF model #5191

Merged: 76 commits, Aug 5, 2024
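
For context, here is a minimal usage sketch of what this PR enables: loading a quantized GGUF checkpoint directly with vLLM's Python API. The file path and tokenizer repo below are illustrative placeholders, not taken from the PR:

    from vllm import LLM, SamplingParams

    # Placeholder GGUF file and tokenizer repo, for illustration only.
    llm = LLM(
        model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",   # local GGUF checkpoint
        tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",   # matching HF tokenizer
    )
    params = SamplingParams(temperature=0.8, max_tokens=32)
    outputs = llm.generate(["The GGUF format is"], params)
    print(outputs[0].outputs[0].text)

Per commit 164b643 below, the tokenizer can also be recovered from the GGUF file via AutoTokenizer, but passing the original model's tokenizer explicitly, as sketched here, sidesteps that conversion.
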
Commits (76):
1ffda2e  init gguf loading support (Isotr0py, Jun 2, 2024)
f3058b1  add gguf running support (Isotr0py, Jun 2, 2024)
259d5b5  Fix numpy warning (Isotr0py, Jun 2, 2024)
0035bdf  Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py, Jun 2, 2024)
995f98e  fix gguf load format (Isotr0py, Jun 2, 2024)
d116f2e  add more example prompts (Isotr0py, Jun 2, 2024)
f387f9e  update requirements.txt (Isotr0py, Jun 2, 2024)
516552a  add dequant runtime (Isotr0py, Jun 3, 2024)
de5950d  remove debug code (Isotr0py, Jun 3, 2024)
5bda5f0  format code (Isotr0py, Jun 4, 2024)
980c018  update gguf example (Isotr0py, Jun 4, 2024)
f969b36  Merge branch 'main' into gguf (Isotr0py, Jun 4, 2024)
e99f521  Merge branch 'vllm-project:main' into gguf (Isotr0py, Jun 5, 2024)
9d36996  Fix requirements.txt (Isotr0py, Jun 5, 2024)
3a18502  rename ggml -> gguf (Isotr0py, Jun 5, 2024)
e194e28  auto detect gguf quant and format (Isotr0py, Jun 5, 2024)
164b643  use autotokenizer to load gguf tokenizer (Isotr0py, Jun 5, 2024)
b055fb3  Add runtime dequantization for all layers (Isotr0py, Jun 6, 2024)
c93c44e  Merge branch 'main' into gguf (Isotr0py, Jun 18, 2024)
8960270  port gguf cuda kernel (Isotr0py, Jun 19, 2024)
1d0c6a4  add qwen2 support and gguf mmq for linear (Isotr0py, Jun 21, 2024)
957faec  remove transformers load_dequant_gguf_tensor (Isotr0py, Jun 21, 2024)
4555cf5  reorder gguf weight iterator (Isotr0py, Jun 22, 2024)
7f7af2b  fix imatrix (Isotr0py, Jun 22, 2024)
87078be  fix imatrix (Isotr0py, Jun 22, 2024)
ca39edf  refactor, fix column parallel (Isotr0py, Jun 22, 2024)
cf03757  refactor gguf_kernel and remove dmmv (Isotr0py, Jun 24, 2024)
c2524a8  refactor to unmerge weights for gguf (Isotr0py, Jun 29, 2024)
446c64a  revert get_quantization_config (Isotr0py, Jun 29, 2024)
dc43654  revert get_quantization_config (Isotr0py, Jun 29, 2024)
2861670  revert qwen2 (Isotr0py, Jun 29, 2024)
1622966  add quant vocal embeddings (Isotr0py, Jun 29, 2024)
c4d4f96  support quantized parallelhead (Isotr0py, Jun 29, 2024)
9a99252  revert qwen2 (Isotr0py, Jun 29, 2024)
bc1ab48  Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py, Jul 3, 2024)
3fad5bd  rebase gguf support (Isotr0py, Jul 3, 2024)
409bed3  format code (Isotr0py, Jul 3, 2024)
b38bd1d  format code (Isotr0py, Jul 3, 2024)
3586f12  support qwen2 gguf (Isotr0py, Jul 4, 2024)
8a56d55  Merge branch 'main' into gguf (Isotr0py, Jul 4, 2024)
defe23f  fix gguf loader (Isotr0py, Jul 4, 2024)
6c4300e  add gguf test (Isotr0py, Jul 4, 2024)
266447b  format code (Isotr0py, Jul 4, 2024)
d5a7e2f  format code (Isotr0py, Jul 4, 2024)
6026e02  remove archs<7.0 in cmakelists (Isotr0py, Jul 4, 2024)
9dc8794  fix a typo (Isotr0py, Jul 4, 2024)
ef9b8a3  format code (Isotr0py, Jul 4, 2024)
b708ce6  format code (Isotr0py, Jul 4, 2024)
be51a27  fix failed model test (Isotr0py, Jul 5, 2024)
1bd7d16  Merge branch 'vllm-project:main' into gguf (Isotr0py, Jul 7, 2024)
c155f74  Merge branch 'main' into gguf (Isotr0py, Jul 10, 2024)
e49f96e  add imatrix and qwen2 test (Isotr0py, Jul 10, 2024)
af0c051  reorganize gguf kernel (Isotr0py, Jul 12, 2024)
0ce3961  exclude gguf copied code (Isotr0py, Jul 12, 2024)
e599b07  refactor to merge weights (Isotr0py, Jul 12, 2024)
25dcc08  forma code (Isotr0py, Jul 12, 2024)
eed9a23  format code (Isotr0py, Jul 12, 2024)
6e5330d  import gguf (Isotr0py, Jul 12, 2024)
e5a61be  import gguf (Isotr0py, Jul 13, 2024)
64c5375  refactor quantized vocal embedding (Isotr0py, Jul 13, 2024)
86ef2b5  optimize docs (Isotr0py, Jul 14, 2024)
7ccfacb  add docs (Isotr0py, Jul 17, 2024)
28dc7b6  Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py, Jul 17, 2024)
1b39fbc  fix llama embed quant (Isotr0py, Jul 17, 2024)
d413f60  Fix CUDA graph with gguf (Isotr0py, Jul 18, 2024)
1868a94  Merge remote-tracking branch 'upstream/main' into gguf (Isotr0py, Jul 28, 2024)
b4e2f29  fix quant embeddings (Isotr0py, Jul 28, 2024)
2cc6753  Merge branch 'main' into gguf (mgoin, Jul 31, 2024)
db54a19  Fix embedding method and format (mgoin, Jul 31, 2024)
2549c3e  Cleanup linear comments (mgoin, Jul 31, 2024)
0890fa9  move gguf to cuda requirements (Isotr0py, Aug 1, 2024)
5166ac9  raise error for gguf when tp>1 (Isotr0py, Aug 1, 2024)
26349db  Merge branch 'main' into gguf (mgoin, Aug 5, 2024)
73da240  Last round of cleanup (mgoin, Aug 5, 2024)
1c83d63  Improve qweight_type size calc (mgoin, Aug 5, 2024)
1139e7b  Fix lm head tests (mgoin, Aug 5, 2024)
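
One commit worth noting above is e194e28 ("auto detect gguf quant and format"). Detecting a GGUF checkpoint is cheap because every GGUF file begins with the four-byte magic b"GGUF"; below is a minimal sketch of such a check, for illustration only and not the PR's actual detection code:

    GGUF_MAGIC = b"GGUF"  # four-byte magic at the start of every GGUF file

    def looks_like_gguf(path: str) -> bool:
        # Hypothetical helper: sniff the header rather than trusting the extension.
        with open(path, "rb") as f:
            return f.read(4) == GGUF_MAGIC
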
Files changed (partial list):
.github/workflows/clang-format.yml (+5, -0)

@@ -30,6 +30,11 @@ jobs:
       run: |
         EXCLUDES=(
           'csrc/moe/topk_softmax_kernels.cu'
+          'csrc/quantization/gguf/ggml-common.h'
+          'csrc/quantization/gguf/dequantize.cuh'
+          'csrc/quantization/gguf/vecdotq.cuh'
+          'csrc/quantization/gguf/mmq.cuh'
+          'csrc/quantization/gguf/mmvq.cuh'
         )
         find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
           | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
CMakeLists.txt (+1, -0)

@@ -210,6 +210,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
+    "csrc/quantization/gguf/gguf_kernel.cu"
     "csrc/quantization/fp8/fp8_marlin.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
csrc/ops.h (+9, -0)

@@ -107,6 +107,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
 
+torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
+                              int64_t n);
+
+torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
+                                  int64_t row);
+
+torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
+                              int64_t row);
+
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
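
These three ops split the GEMM work by batch size: a fused dequantize + mat-vec kernel, a quantized matmul (MMQ) kernel, and a full dequantize for everything else. A hypothetical Python-side dispatch sketch follows; the torch.ops._C binding names, the batch-size cutoff, and the dequantized weight layout are assumptions for illustration, not the PR's exact logic:

    import torch

    def gguf_matmul(W: torch.Tensor, x: torch.Tensor, qweight_type: int,
                    out_rows: int) -> torch.Tensor:
        # Hypothetical wrapper; op names assumed to mirror the C++ declarations above.
        if x.shape[0] == 1:
            # Single row of activations: fused dequantize + mat-vec kernel (mmvq).
            return torch.ops._C.ggml_mul_mat_vec_a8(W, x, qweight_type, out_rows)
        elif x.shape[0] <= 8:
            # Small batches: quantized matmul kernel (mmq).
            return torch.ops._C.ggml_mul_mat_a8(W, x, qweight_type, out_rows)
        else:
            # Larger batches: dequantize the weight once, then use a dense matmul
            # (assuming the dequantized weight has shape [out_rows, in_features]).
            weight = torch.ops._C.ggml_dequantize(W, qweight_type, out_rows,
                                                  x.shape[-1])
            return x @ weight.t()

Commit cf03757 above ("refactor gguf_kernel and remove dmmv") suggests the mat-vec path kept in this PR is the mmvq variant, which is why only mmvq.cuh and mmq.cuh appear in the clang-format exclusion list.
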