[REFRACTOR] Cleanup backend and model_type usage #276

Merged 8 commits on Jul 23, 2024

Changes from all commits
gptqmodel/models/base.py (28 changes: 9 additions & 19 deletions)
@@ -83,7 +83,6 @@ def __init__(
         super().__init__()
 
         self.model = model
-        self.model_type = self.model.config.model_type
         self._quantized = quantized
         self.quantize_config = quantize_config
         self.config = self.model.config
@@ -559,18 +558,8 @@ def forward(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
     def generate(self, **kwargs):
-        """shortcut for model.generate"""
-        from ..utils.sglang import sglang_generate
-        from ..utils.vllm import vllm_generate
-        if hasattr(self.model.config, "model_type") and self.model.config.model_type == "vllm":
-            with torch.inference_mode():
-                return vllm_generate(self.model, **kwargs)
-        elif hasattr(self.model.config, "model_type") and self.model.config.model_type == "sglang":
-            with torch.inference_mode():
-                return sglang_generate(**kwargs)
-        else:
-            with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
-                return self.model.generate(**kwargs)
+        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
+            return self.model.generate(**kwargs)
 
     def prepare_inputs_for_generation(self, *args, **kwargs):
         """shortcut for model.prepare_inputs_for_generation"""
@@ -853,7 +842,7 @@ def from_quantized(
         verify_hash: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
-        if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
+        if backend == BACKEND.VLLM:
            import os
            # to optimize vllm inference, set an environment variable 'VLLM_ATTENTION_BACKEND' to 'FLASHINFER'.
            os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'
@@ -930,7 +919,7 @@ def from_quantized(
            if quantize_config.format != FORMAT.GPTQ:
                raise ValueError(f"{backend} backend only supports FORMAT.GPTQ: actual = {quantize_config.format}")
            if backend == BACKEND.VLLM:
-                from ..utils.vllm import load_model_by_vllm
+                from ..utils.vllm import load_model_by_vllm, vllm_generate
 
                model = load_model_by_vllm(
                    model=model_name_or_path,
@@ ... @@
                )
 
                model.config = model.llm_engine.model_config
-                model.config.model_type = "vllm"
+
+                cls.generate = lambda self, **kwargs: vllm_generate(self.model, **kwargs)
 
            elif backend == BACKEND.SGLANG:
-                from ..utils.sglang import load_model_by_sglang
+                from ..utils.sglang import load_model_by_sglang, sglang_generate
 
                model, hf_config = load_model_by_sglang(
                    model=model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
                model.config = hf_config
-                model.config.model_type = "sglang"
+                cls.generate = lambda self, **kwargs: sglang_generate(self.model, **kwargs)
            return cls(
                model,
                quantized=True,
@@ -1182,7 +1172,7 @@ def skip(*args, **kwargs):
        if is_sharded:
            raise ValueError(
                "The loading of sharded checkpoints with BitBLAS is currently not supported. Please raise an issue in GPTQModel repository.")
 
        # Prepare model for bitblas load.
        # If is bitblas serialized load then load directly. Otherwise, convert to bitblas.
        model = prepare_model_for_bitblas_load(
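For orientation, a hedged usage sketch of the two generate() paths after this cleanup. The model ID for the default path is a placeholder, and the device and generation kwargs are assumptions, not taken from this PR: the default path always wraps the underlying HF generate in inference_mode plus autocast, while a BACKEND.VLLM load rebinds cls.generate to vllm_generate so the same attribute goes through the vLLM engine.

# Illustrative sketch only; placeholder model ID, device, and kwargs are assumptions.
from transformers import AutoTokenizer
from gptqmodel import BACKEND, GPTQModel

# Default (HF) path: generate() now always runs under
# torch.inference_mode() + torch.amp.autocast().
hf_quant_id = "some-org/your-gptq-4bit-model"  # placeholder; substitute a real GPTQ checkpoint
model = GPTQModel.from_quantized(hf_quant_id, device="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(hf_quant_id)
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))

# vLLM path: from_quantized() binds cls.generate to vllm_generate, so the same
# call site forwards to the vLLM engine; kwargs and return shape are assumed to
# mirror the vLLM RequestOutput objects used in tests/test_vllm.py below.
vllm_model = GPTQModel.from_quantized("ModelCloud/gemma-2-27b-it-gptq-4bit", backend=BACKEND.VLLM)
for output in vllm_model.generate(prompts=["The capital of France is"]):
    print(output.outputs[0].text)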
gptqmodel/models/gemma2.py (28 changes: 22 additions & 6 deletions)
@@ -2,14 +2,16 @@
 from logging import getLogger
 
 from .base import BaseGPTQModel
-
+from ..utils import BACKEND
 logger = getLogger(__name__)
 handler = logging.StreamHandler()
 formatter = logging.Formatter("%(levelname)s - %(message)s")
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 logger.setLevel(logging.INFO)
 
+SUPPORT_ERR = "Currently, only vLLM/SGLang with flashinfer enabled can correctly inference a quantized Gemma2-27B model. Pre-quantized model with sample vLLM code: https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit ."
+
 class Gemma2GPTQ(BaseGPTQModel):
     base_modules = ["model.embed_tokens", "model.norm"]
 
@@ -26,13 +28,27 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         # There is an issue with duplicate outputs in the quantized gemma-2 model 27b with transformers.
-        # Until this issue is fixed, quantized gemma-2 27b model only support vLLM load.
         if hasattr(self.model.config, "num_hidden_layers"):
             num_hidden_layers = getattr(self.model.config, "num_hidden_layers")
             # The gemma-2 model 9b has 42 hidden layers, while the gemma-2 model 27b has 46 hidden layers.
             if num_hidden_layers > 42:
-                if self.quantized:
-                    raise ValueError("Currently, only vllm can load the quantized gemma2-27b for proper inference. https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit is a quantized gemma-2-27b-it model, along with an example of loading it using vLLM.")
-                else:
-                    logger.warning("Currently, only vllm can load the quantized gemma2-27b for proper inference. https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit is a quantized gemma-2-27b-it model, along with an example of loading it using vLLM.")
+                if not self.quantized:
+                    logger.warning(SUPPORT_ERR)
+                    return
+
+                # quantized gemma-2 27b model only support vLLM/SGLang load.
+                from ..utils.vllm import VLLM_AVAILABLE
+                if VLLM_AVAILABLE:
+                    from vllm import LLM
+                    if isinstance(self.model, LLM):
+                        backend = BACKEND.VLLM
+
+                from ..utils.sglang import SGLANG_AVAILABLE
+                if SGLANG_AVAILABLE:
+                    from sglang.srt.server import Runtime
+                    if isinstance(self.model, Runtime):
+                        backend = BACKEND.SGLANG
+
+                if backend not in [BACKEND.VLLM, BACKEND.SGLANG]:
+                    raise ValueError(SUPPORT_ERR)

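The 42-layer threshold above separates the two released gemma-2 sizes. A short hedged check of that heuristic, assuming the card referenced in SUPPORT_ERR is publicly readable and exposes the standard transformers config fields:

# Sketch: gemma-2 9B reports 42 hidden layers and 27B reports 46,
# so "num_hidden_layers > 42" flags the 27B variant that needs vLLM/SGLang.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("ModelCloud/gemma-2-27b-it-gptq-4bit")
print(getattr(cfg, "num_hidden_layers", None))  # expected: 46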
gptqmodel/utils/sglang.py (3 changes: 3 additions & 0 deletions)
@@ -1,5 +1,6 @@
 import multiprocessing as mp
 
+import torch
 from transformers import AutoConfig
 
 try:
@@ -38,7 +39,9 @@ def generate(s, prompt, **kwargs):
     def generate(s, prompt, **kwargs):
         print(SGLANG_INSTALL_HINT)
 
+@torch.inference_mode
 def sglang_generate(
+    model,
     **kwargs,
 ):
     if not SGLANG_AVAILABLE:
gptqmodel/utils/vllm.py (3 changes: 3 additions & 0 deletions)
@@ -1,5 +1,7 @@
 from typing import Any, Dict
 
+import torch
+
 try:
     from vllm import LLM, SamplingParams
     VLLM_AVAILABLE = True
@@ -42,6 +44,7 @@ def load_model_by_vllm(
 
     return model
 
+@torch.inference_mode
 def vllm_generate(
     model,
     **kwargs,
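The decorator added to both utility modules is why the torch import is now needed. A minimal self-contained check of the behavior torch.inference_mode provides, written here in its parenthesized decorator form:

# Tensors produced under inference_mode do not track gradients and cannot be
# used later in autograd, which is what the decorated generate helpers rely on.
import torch

@torch.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

y = double(torch.ones(2, requires_grad=True))
print(y.requires_grad)   # False
print(y.is_inference())  # True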
tests/test_vllm.py (16 changes: 14 additions & 2 deletions)
@@ -1,3 +1,4 @@
+import gc
 import os
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -9,7 +10,7 @@
 
 import torch # noqa: E402
 from gptqmodel import BACKEND, GPTQModel # noqa: E402
-
+from vllm.distributed.parallel_state import destroy_model_parallel
 
 class TestLoadVLLM(unittest.TestCase):
 
@@ -52,6 +53,12 @@ def test_load_vllm(self):
             print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         self.assertEquals(generated_text, " Paris. 2. Name the capital of the United States. 3.")
 
+        destroy_model_parallel()
+        del model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
     def test_load_shared_vllm(self):
         model = GPTQModel.from_quantized(
             self.SHARDED_MODEL_ID,
@@ -69,4 +76,9 @@ def test_load_shared_vllm(self):
             generated_text = output.outputs[0].text
             print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         self.assertEquals(generated_text,
-                          " Paris.\n2. Who has a national flag with a white field surrounded by")
+                          " Paris.\n2. Who has a national flag with a white field surrounded by")
+
+        destroy_model_parallel()
+        del model
+        gc.collect()
+        torch.cuda.empty_cache()
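The same four-line teardown now appears at the end of both tests. A hedged refactor sketch that keeps the behavior but centralizes it; the helper name is hypothetical and not part of this PR:

# Hypothetical helper, not in the PR. Callers still drop their own model
# reference (del model) before calling it so the vLLM engine can be collected.
import gc

import torch
from vllm.distributed.parallel_state import destroy_model_parallel


def release_vllm_engine() -> None:
    """Tear down vLLM parallel groups and release cached CUDA memory."""
    destroy_model_parallel()
    gc.collect()
    torch.cuda.empty_cache()

# usage inside a test: del model; release_vllm_engine()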