[Core] Support loading GGUF model #5191

Merged 76 commits on Aug 5, 2024

Changes from 12 commits

Commits (76)
1ffda2e
init gguf loading support
Isotr0py Jun 2, 2024
f3058b1
add gguf running support
Isotr0py Jun 2, 2024
259d5b5
Fix numpy warning
Isotr0py Jun 2, 2024
0035bdf
Merge remote-tracking branch 'upstream/main' into gguf
Isotr0py Jun 2, 2024
995f98e
fix gguf load format
Isotr0py Jun 2, 2024
d116f2e
add more example prompts
Isotr0py Jun 2, 2024
f387f9e
update requirements.txt
Isotr0py Jun 2, 2024
516552a
add dequant runtime
Isotr0py Jun 3, 2024
de5950d
remove debug code
Isotr0py Jun 3, 2024
5bda5f0
format code
Isotr0py Jun 4, 2024
980c018
update gguf example
Isotr0py Jun 4, 2024
f969b36
Merge branch 'main' into gguf
Isotr0py Jun 4, 2024
e99f521
Merge branch 'vllm-project:main' into gguf
Isotr0py Jun 5, 2024
9d36996
Fix requirements.txt
Isotr0py Jun 5, 2024
3a18502
rename ggml -> gguf
Isotr0py Jun 5, 2024
e194e28
auto detect gguf quant and format
Isotr0py Jun 5, 2024
164b643
use autotokenizer to load gguf tokenizer
Isotr0py Jun 5, 2024
b055fb3
Add runtime dequantization for all layers
Isotr0py Jun 6, 2024
c93c44e
Merge branch 'main' into gguf
Isotr0py Jun 18, 2024
8960270
port gguf cuda kernel
Isotr0py Jun 19, 2024
1d0c6a4
add qwen2 support and gguf mmq for linear
Isotr0py Jun 21, 2024
957faec
remove transformers load_dequant_gguf_tensor
Isotr0py Jun 21, 2024
4555cf5
reorder gguf weight iterator
Isotr0py Jun 22, 2024
7f7af2b
fix imatrix
Isotr0py Jun 22, 2024
87078be
fix imatrix
Isotr0py Jun 22, 2024
ca39edf
refactor, fix column parallel
Isotr0py Jun 22, 2024
cf03757
refactor gguf_kernel and remove dmmv
Isotr0py Jun 24, 2024
c2524a8
refactor to unmerge weights for gguf
Isotr0py Jun 29, 2024
446c64a
revert get_quantization_config
Isotr0py Jun 29, 2024
dc43654
revert get_quantization_config
Isotr0py Jun 29, 2024
2861670
revert qwen2
Isotr0py Jun 29, 2024
1622966
add quant vocal embeddings
Isotr0py Jun 29, 2024
c4d4f96
support quantized parallelhead
Isotr0py Jun 29, 2024
9a99252
revert qwen2
Isotr0py Jun 29, 2024
bc1ab48
Merge remote-tracking branch 'upstream/main' into gguf
Isotr0py Jul 3, 2024
3fad5bd
rebase gguf support
Isotr0py Jul 3, 2024
409bed3
format code
Isotr0py Jul 3, 2024
b38bd1d
format code
Isotr0py Jul 3, 2024
3586f12
support qwen2 gguf
Isotr0py Jul 4, 2024
8a56d55
Merge branch 'main' into gguf
Isotr0py Jul 4, 2024
defe23f
fix gguf loader
Isotr0py Jul 4, 2024
6c4300e
add gguf test
Isotr0py Jul 4, 2024
266447b
format code
Isotr0py Jul 4, 2024
d5a7e2f
format code
Isotr0py Jul 4, 2024
6026e02
remove archs<7.0 in cmakelists
Isotr0py Jul 4, 2024
9dc8794
fix a typo
Isotr0py Jul 4, 2024
ef9b8a3
format code
Isotr0py Jul 4, 2024
b708ce6
format code
Isotr0py Jul 4, 2024
be51a27
fix failed model test
Isotr0py Jul 5, 2024
1bd7d16
Merge branch 'vllm-project:main' into gguf
Isotr0py Jul 7, 2024
c155f74
Merge branch 'main' into gguf
Isotr0py Jul 10, 2024
e49f96e
add imatrix and qwen2 test
Isotr0py Jul 10, 2024
af0c051
reorganize gguf kernel
Isotr0py Jul 12, 2024
0ce3961
exclude gguf copied code
Isotr0py Jul 12, 2024
e599b07
refactor to merge weights
Isotr0py Jul 12, 2024
25dcc08
forma code
Isotr0py Jul 12, 2024
eed9a23
format code
Isotr0py Jul 12, 2024
6e5330d
import gguf
Isotr0py Jul 12, 2024
e5a61be
import gguf
Isotr0py Jul 13, 2024
64c5375
refactor quantized vocal embedding
Isotr0py Jul 13, 2024
86ef2b5
optimize docs
Isotr0py Jul 14, 2024
7ccfacb
add docs
Isotr0py Jul 17, 2024
28dc7b6
Merge remote-tracking branch 'upstream/main' into gguf
Isotr0py Jul 17, 2024
1b39fbc
fix llama embed quant
Isotr0py Jul 17, 2024
d413f60
Fix CUDA graph with gguf
Isotr0py Jul 18, 2024
1868a94
Merge remote-tracking branch 'upstream/main' into gguf
Isotr0py Jul 28, 2024
b4e2f29
fix quant embeddings
Isotr0py Jul 28, 2024
2cc6753
Merge branch 'main' into gguf
mgoin Jul 31, 2024
db54a19
Fix embedding method and format
mgoin Jul 31, 2024
2549c3e
Cleanup linear comments
mgoin Jul 31, 2024
0890fa9
move gguf to cuda requirements
Isotr0py Aug 1, 2024
5166ac9
raise error for gguf when tp>1
Isotr0py Aug 1, 2024
26349db
Merge branch 'main' into gguf
mgoin Aug 5, 2024
73da240
Last round of cleanup
mgoin Aug 5, 2024
1c83d63
Improve qweight_type size calc
mgoin Aug 5, 2024
1139e7b
Fix lm head tests
mgoin Aug 5, 2024
35 changes: 35 additions & 0 deletions examples/gguf_inference.py
@@ -0,0 +1,35 @@
from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams


def run_gguf_inference(model_path):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model=model_path,
tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
load_format="gguf",
quantization="ggml")

outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
model = hf_hub_download(repo_id, filename=filename)
run_gguf_inference(model)
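
For a quick sanity check before running the example, the downloaded file's metadata can be inspected with the gguf package that this PR adds to the dev requirements below. The sketch here is illustrative only: it assumes gguf's GGUFReader API and is not part of the diff.

# Inspect a GGUF file's header fields and tensor quantization types.
# Assumes the `gguf` Python package and its GGUFReader API.
from gguf import GGUFReader
from huggingface_hub import hf_hub_download

repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
path = hf_hub_download(repo_id, filename=filename)

reader = GGUFReader(path)
for name in list(reader.fields)[:8]:          # header metadata keys
    print(name)
for tensor in reader.tensors[:8]:             # tensor name, GGML type, shape
    print(tensor.name, tensor.tensor_type, tensor.shape)
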
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -5,7 +5,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy
requests
py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
transformers >= 4.41.0 # Required for StarCoder2 & Llava, Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
6 changes: 6 additions & 0 deletions requirements-dev.txt
@@ -33,5 +33,11 @@ sentence-transformers # required for embedding
# Benchmarking
aiohttp

# Multimodal
pillow

# GGUF
gguf

# quantization
bitsandbytes==0.42.0
1 change: 1 addition & 0 deletions vllm/config.py
@@ -493,6 +493,7 @@ class LoadFormat(str, enum.Enum):
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"


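
The new GGUF member extends LoadFormat, which subclasses str, so the plain string "gguf" coming from engine arguments maps directly onto the enum. A tiny illustration of that behaviour, using a trimmed-down enum rather than the full vLLM definition:

import enum

class LoadFormat(str, enum.Enum):
    DUMMY = "dummy"
    GGUF = "gguf"

assert LoadFormat("gguf") is LoadFormat.GGUF   # parse the engine-arg string
assert LoadFormat.GGUF == "gguf"               # str subclass: compares as a string
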
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -11,6 +11,7 @@
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.ggml import GGMLConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
@@ -27,6 +28,7 @@
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"ggml": GGMLConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
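
Registering "ggml" here is what lets quantization="ggml" resolve to GGMLConfig (defined in the new ggml.py below). A simplified sketch of that lookup pattern; the registry dict and helper are stand-ins for illustration, not vLLM's exact code:

from typing import Dict, Type

class QuantizationConfig:            # placeholder base class for this sketch
    pass

class GGMLConfig(QuantizationConfig):
    pass

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "ggml": GGMLConfig,
    # ... gptq, marlin, fp8 and the rest are registered the same way
}

def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]

assert get_quantization_config("ggml") is GGMLConfig
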
100 changes: 100 additions & 0 deletions vllm/model_executor/layers/quantization/ggml.py
@@ -0,0 +1,100 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs


class GGMLConfig(QuantizationConfig):
"""Config class for GGML."""

def __init__(self, ) -> None:
pass

def __repr__(self) -> str:
return ("GGMLConfig()")

def get_name(self) -> str:
return "ggml"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]

def get_min_capability(self) -> int:
return 70

@classmethod
def get_config_filenames(cls) -> List[str]:
return [] # no extra configs.

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GGMLConfig":
return cls()

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GGMLLinearMethod"]:
if isinstance(layer, LinearBase):
return GGMLLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []


class GGMLLinearMethod(LinearMethodBase):
"""Linear method for GGML.

Args:
quant_config: The GGML quantization config.
"""

def __init__(self, quant_config: GGMLConfig):
self.quant_config = quant_config
self.block_size = 32

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
quants = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
set_weight_attrs(quants, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(quants, extra_weight_attrs)
layer.register_parameter("quants", quants)

scales = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.block_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 1,
"output_dim": 0,
"ggml_scales": True
})
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("scales", scales)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# dequantized for Q4_0 and Q8_0
shape = layer.quants.shape
out = layer.quants.reshape(-1, self.block_size) * layer.scales.reshape(
-1, 1)
out = torch.matmul(x, out.reshape(shape).T)
if bias is not None:
out.add_(bias)
return out
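
GGMLLinearMethod keeps the raw int8 quants plus one scale per 32-value block and dequantizes on the fly inside apply(). A self-contained sketch of that computation on toy tensors (sizes and values are made up, not taken from the PR):

import torch

block_size = 32
out_features, in_features = 8, 64      # toy sizes; in_features % block_size == 0

quants = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, in_features // block_size)

# Broadcast one scale over each block of 32 quantized values, as apply() does.
weight = quants.reshape(-1, block_size) * scales.reshape(-1, 1)
weight = weight.reshape(out_features, in_features)

x = torch.randn(4, in_features)
out = x @ weight.T                     # matches apply()'s matmul (bias added after)
print(out.shape)                       # torch.Size([4, 8])
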
76 changes: 76 additions & 0 deletions vllm/model_executor/model_loader/gguf.py
@@ -0,0 +1,76 @@
# adapted from
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/integrations/ggml.py
"""
Integration with GGML / The file is copied and adapted from https://github.com/99991/pygguf
with extra methods being exposed
"""
import numpy as np
import torch
from transformers.integrations.ggml import (GGML_BLOCK_SIZES, GGML_TYPES,
load_dequant_gguf_tensor)


def convert_tensor_q4_0(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11
block_size = GGML_BLOCK_SIZES["Q4_0"]
num_blocks = len(data) // block_size

data_f16 = np.frombuffer(data,
dtype=np.float16).reshape(num_blocks,
block_size // 2)
data_u8 = np.frombuffer(data,
dtype=np.uint8).reshape(num_blocks, block_size)

# The scales are stored on the first 2 bytes
# and the rest corresponds to the quants
scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)

# the rest of the bytes corresponds to the quants
# - we discard the first two bytes
quants = data_u8[:, 2:]

ql = (quants[:, :] & 0xF).astype(np.int8) - 8
qr = (quants[:, :] >> 4).astype(np.int8) - 8

# Use hstack
quants = np.hstack([ql, qr])

return scales, quants


def convert_tensor_q8_0(data):
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
block_size = GGML_BLOCK_SIZES["Q8_0"]
num_blocks = len(data) // block_size

scales = (np.frombuffer(data, dtype=np.float16).reshape(
num_blocks, 1 + 16)[:, :1].astype(np.float32))
quants = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:,
2:]

return scales, quants


def load_gguf_tensor(tensor):
shape, ggml_type, data = tensor.shape, tensor.tensor_type, tensor.data

scales = None
if ggml_type == GGML_TYPES["Q8_0"] and "blk" in tensor.name:
scales, quants = convert_tensor_q8_0(data)
elif ggml_type == GGML_TYPES["Q4_0"] and "blk" in tensor.name:
scales, quants = convert_tensor_q4_0(data)
else:
quants = load_dequant_gguf_tensor(shape, ggml_type, data)
quants = torch.from_numpy(quants.copy())
return scales, quants

scales_shape = (int(shape[0] // 32), shape[1])
scales = scales.reshape(scales_shape[::-1])
quants = quants.reshape(shape[::-1])
scales = torch.from_numpy(scales.copy())
quants = torch.from_numpy(quants.copy())
return scales, quants
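
convert_tensor_q4_0 above splits each 18-byte Q4_0 block (one fp16 scale followed by 32 packed 4-bit values) into separate scale and quant arrays, so dequantization reduces to scale * (q - 8), with the low nibbles giving the first 16 values of a block and the high nibbles the last 16. A self-contained round-trip sketch of that byte layout, for illustration only:

import numpy as np

rng = np.random.default_rng(0)
d = np.float16(0.05)                               # per-block fp16 scale
q = rng.integers(0, 16, size=32, dtype=np.uint8)   # 32 unsigned 4-bit codes

# Pack one block: 2 bytes of fp16 scale + 16 bytes of packed nibbles (18 bytes,
# i.e. GGML_BLOCK_SIZES["Q4_0"]).
packed = (q[:16] | (q[16:] << 4)).astype(np.uint8)
block = d.tobytes() + packed.tobytes()

# Unpack exactly as convert_tensor_q4_0 does, for a single block.
scale = np.frombuffer(block, dtype=np.float16, count=1)[0].astype(np.float32)
data_u8 = np.frombuffer(block, dtype=np.uint8)[2:]
ql = (data_u8 & 0xF).astype(np.int8) - 8
qr = (data_u8 >> 4).astype(np.int8) - 8
weights = scale * np.hstack([ql, qr]).astype(np.float32)

assert np.allclose(weights, np.float32(d) * (q.astype(np.float32) - 8))
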
73 changes: 63 additions & 10 deletions vllm/model_executor/model_loader/loader.py
@@ -22,6 +22,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.ggml import GGMLConfig
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
tensorizer_weights_iterator)
@@ -30,8 +31,10 @@
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
get_quant_config, gguf_dequant_weights_iterator,
gguf_quant_weights_iterator, initialize_dummy_weights,
np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.utils import set_weight_attrs

@@ -44,14 +47,15 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
if not isinstance(quant_config, GGMLConfig):
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
f"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
@@ -781,6 +785,52 @@ def load_model(self, *, model_config: ModelConfig,
return model.eval()


class GGUFModelLoader(BaseModelLoader):
"""
Model loader that can load GGUF files. This is useful for loading models
that are quantized with GGUF and saved in the GGUF format. This loader
supports loading both full models and sharded models.
"""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")

def _prepare_weights(self, model_name_or_path: str):
if os.path.isfile(model_name_or_path):
return model_name_or_path
else:
raise ValueError(f"{model_name_or_path} is not a file.")

def _get_weights_iterator(
self, model_name_or_path: str, quantization: str
) -> Generator[Tuple[str, torch.Tensor], None, None]:
local_model_path = self._prepare_weights(model_name_or_path)
if quantization == "ggml":
return gguf_quant_weights_iterator(local_model_path)
else:
return gguf_dequant_weights_iterator(local_model_path)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.quantization))
return model


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

@@ -799,4 +849,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)

if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)

return DefaultModelLoader(load_config)
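
GGUFModelLoader feeds model.load_weights a generator of (parameter name, tensor) pairs, produced by either gguf_quant_weights_iterator (raw quants plus scales, used when quantization="ggml") or gguf_dequant_weights_iterator (fully dequantized tensors). A minimal mock of that iterator contract; the parameter names and shapes below are illustrative, not taken from the PR:

from typing import Generator, Tuple

import torch

def fake_gguf_weights_iterator(
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    # Dequantized path: one ready-to-copy tensor per model parameter.
    yield ("model.embed_tokens.weight",
           torch.zeros(32000, 2048, dtype=torch.float16))
    yield ("model.layers.0.self_attn.q_proj.weight",
           torch.zeros(2048, 2048, dtype=torch.float16))
    # The quantized path would instead yield the int8 quants and per-block scales
    # registered by GGMLLinearMethod, e.g. "...q_proj.quants" / "...q_proj.scales".

for name, tensor in fake_gguf_weights_iterator():
    print(name, tuple(tensor.shape), tensor.dtype)
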