[REFACTOR] Cleanup backend and model_type usage (#276)
* cleanup backend and model_type usage

* Update gemma2.py

* Update gemma2.py

* cleanup self.backend

* use lambda function

* release memory after vllm test case finish

---------

Co-authored-by: LRL-ModelCloud <lrl@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
3 people authored Jul 23, 2024
1 parent 3d88fdf commit 9197738
Showing 5 changed files with 50 additions and 26 deletions.
26 changes: 8 additions & 18 deletions gptqmodel/models/base.py
@@ -85,7 +85,6 @@ def __init__(
super().__init__()

self.model = model
self.model_type = self.model.config.model_type
self._quantized = quantized
self.quantize_config = quantize_config
self.config = self.model.config
@@ -565,18 +564,8 @@ def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def generate(self, **kwargs):
"""shortcut for model.generate"""
from ..utils.sglang import sglang_generate
from ..utils.vllm import vllm_generate
if hasattr(self.model.config, "model_type") and self.model.config.model_type == "vllm":
with torch.inference_mode():
return vllm_generate(self.model, **kwargs)
elif hasattr(self.model.config, "model_type") and self.model.config.model_type == "sglang":
with torch.inference_mode():
return sglang_generate(**kwargs)
else:
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(**kwargs)
with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
return self.model.generate(**kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
@@ -859,7 +848,7 @@ def from_quantized(
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
):
if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
if backend == BACKEND.VLLM:
import os
# to optimize vllm inference, set an environment variable 'VLLM_ATTENTION_BACKEND' to 'FLASHINFER'.
os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'
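As the comment above notes, the loader now exports VLLM_ATTENTION_BACKEND=FLASHINFER whenever the vLLM backend is selected. For standalone vLLM use outside from_quantized, the same switch can be set by hand. A minimal sketch, assuming the flashinfer package is installed (vLLM needs it to honor this backend) and using the pre-quantized checkpoint linked later in this commit:

import os

# Export before vLLM constructs its engine, otherwise the default attention
# backend is picked; setting it before the import is the safe choice.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="ModelCloud/gemma-2-27b-it-gptq-4bit")
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
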
@@ -936,7 +925,7 @@ def from_quantized(
if quantize_config.format != FORMAT.GPTQ:
raise ValueError(f"{backend} backend only supports FORMAT.GPTQ: actual = {quantize_config.format}")
if backend == BACKEND.VLLM:
from ..utils.vllm import load_model_by_vllm
from ..utils.vllm import load_model_by_vllm, vllm_generate

model = load_model_by_vllm(
model=model_name_or_path,
@@ -945,18 +934,19 @@
)

model.config = model.llm_engine.model_config
model.config.model_type = "vllm"

cls.generate = lambda self, **kwargs: vllm_generate(self.model, **kwargs)

elif backend == BACKEND.SGLANG:
from ..utils.sglang import load_model_by_sglang
from ..utils.sglang import load_model_by_sglang, sglang_generate

model, hf_config = load_model_by_sglang(
model=model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
model.config = hf_config
model.config.model_type = "sglang"
cls.generate = lambda self, **kwargs: sglang_generate(self.model, **kwargs)
return cls(
model,
quantized=True,
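Taken together, the base.py changes replace the old dispatch inside generate(), which keyed off a synthetic model.config.model_type of "vllm" or "sglang", with a one-time rebinding of cls.generate to a lambda when the backend is loaded. A minimal sketch of that pattern, using hypothetical stand-ins rather than the library's real vllm_generate/sglang_generate helpers:

import torch

# Hypothetical stand-ins for the backend-specific helpers imported in the diff.
def fake_vllm_generate(model, **kwargs):
    return f"vllm: {kwargs.get('prompts')}"

def fake_sglang_generate(model, **kwargs):
    return f"sglang: {kwargs.get('prompt')}"

class QuantizedModel:
    def __init__(self, model):
        self.model = model
        self.device = torch.device("cpu")

    def generate(self, **kwargs):
        # Default path: delegate to the wrapped HF model under inference_mode/autocast.
        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
            return self.model.generate(**kwargs)

    @classmethod
    def from_backend(cls, model, backend):
        # Bind generate once at load time instead of branching on a fake
        # model_type on every call, mirroring the cls.generate assignments above.
        if backend == "vllm":
            cls.generate = lambda self, **kwargs: fake_vllm_generate(self.model, **kwargs)
        elif backend == "sglang":
            cls.generate = lambda self, **kwargs: fake_sglang_generate(self.model, **kwargs)
        return cls(model)

m = QuantizedModel.from_backend(model=None, backend="vllm")
print(m.generate(prompts=["Hello"]))  # routed through the vLLM stand-in

Because the lambda is assigned to the class rather than the instance, every instance of that model class in the process shares the rebinding; the diff does the same by assigning to cls.generate.
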
28 changes: 22 additions & 6 deletions gptqmodel/models/gemma2.py
@@ -2,14 +2,16 @@
from logging import getLogger

from .base import BaseGPTQModel

from ..utils import BACKEND
logger = getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

SUPPORT_ERR = "Currently, only vLLM/SGLang with flashinfer enabled can correctly inference a quantized Gemma2-27B model. Pre-quantized model with sample vLLM code: https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit ."

class Gemma2GPTQ(BaseGPTQModel):
base_modules = ["model.embed_tokens", "model.norm"]

@@ -26,13 +28,27 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# There is an issue with duplicate outputs in the quantized gemma-2 model 27b with transformers.
# Until this issue is fixed, quantized gemma-2 27b model only support vLLM load.
if hasattr(self.model.config, "num_hidden_layers"):
num_hidden_layers = getattr(self.model.config, "num_hidden_layers")
# The gemma-2 model 9b has 42 hidden layers, while the gemma-2 model 27b has 46 hidden layers.
if num_hidden_layers > 42:
if self.quantized:
raise ValueError("Currently, only vllm can load the quantized gemma2-27b for proper inference. https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit is a quantized gemma-2-27b-it model, along with an example of loading it using vLLM.")
else:
logger.warning("Currently, only vllm can load the quantized gemma2-27b for proper inference. https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit is a quantized gemma-2-27b-it model, along with an example of loading it using vLLM.")
if not self.quantized:
logger.warning(SUPPORT_ERR)
return

# quantized gemma-2 27b model only support vLLM/SGLang load.
from ..utils.vllm import VLLM_AVAILABLE
if VLLM_AVAILABLE:
from vllm import LLM
if isinstance(self.model, LLM):
backend = BACKEND.VLLM

from ..utils.sglang import SGLANG_AVAILABLE
if SGLANG_AVAILABLE:
from sglang.srt.server import Runtime
if isinstance(self.model, Runtime):
backend = BACKEND.SGLANG

if backend not in [BACKEND.VLLM, BACKEND.SGLANG]:
raise ValueError(SUPPORT_ERR)
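The guard above no longer consults model_type; it infers the active backend from the type of the loaded object (vllm.LLM or sglang.srt.server.Runtime) and rejects any other path for a quantized 27B model. A standalone sketch of the same detection idea, with a hypothetical detect_backend helper:

def detect_backend(model) -> str:
    """Guess the serving backend from the type of the loaded model object."""
    try:
        from vllm import LLM
        if isinstance(model, LLM):
            return "vllm"
    except ImportError:
        pass  # vLLM not installed; fall through

    try:
        from sglang.srt.server import Runtime
        if isinstance(model, Runtime):
            return "sglang"
    except ImportError:
        pass  # SGLang not installed; fall through

    return "transformers"  # plain HF model or an unrecognized wrapper
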

3 changes: 3 additions & 0 deletions gptqmodel/utils/sglang.py
@@ -1,5 +1,6 @@
import multiprocessing as mp

import torch
from transformers import AutoConfig

try:
@@ -38,7 +39,9 @@ def generate(s, prompt, **kwargs):
def generate(s, prompt, **kwargs):
print(SGLANG_INSTALL_HINT)

@torch.inference_mode
def sglang_generate(
model,
**kwargs,
):
if not SGLANG_AVAILABLE:
3 changes: 3 additions & 0 deletions gptqmodel/utils/vllm.py
@@ -1,5 +1,7 @@
from typing import Any, Dict

import torch

try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
@@ -42,6 +44,7 @@ def load_model_by_vllm(

return model

@torch.inference_mode
def vllm_generate(
model,
**kwargs,
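Both sglang_generate and vllm_generate are now decorated with torch.inference_mode, so callers no longer need to open the context manager themselves. A small usage sketch of the decorator form (written here with explicit parentheses):

import torch

@torch.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    # The whole function body runs with autograd disabled: no graph is
    # recorded, which saves memory on pure-inference code paths.
    return x * 2

x = torch.ones(3, requires_grad=True)
y = double(x)
print(y.requires_grad)  # False: the result carries no gradient history
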
16 changes: 14 additions & 2 deletions tests/test_vllm.py
@@ -1,3 +1,4 @@
import gc
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -9,7 +10,7 @@

import torch # noqa: E402
from gptqmodel import BACKEND, GPTQModel # noqa: E402

from vllm.distributed.parallel_state import destroy_model_parallel

class TestLoadVLLM(unittest.TestCase):

@@ -52,6 +53,12 @@ def test_load_vllm(self):
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " Paris. 2. Name the capital of the United States. 3.")

destroy_model_parallel()
del model
gc.collect()
torch.cuda.empty_cache()


def test_load_shared_vllm(self):
model = GPTQModel.from_quantized(
self.SHARDED_MODEL_ID,
@@ -69,4 +76,9 @@ def test_load_shared_vllm(self):
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text,
" Paris.\n2. Who has a national flag with a white field surrounded by")
" Paris.\n2. Who has a national flag with a white field surrounded by")

destroy_model_parallel()
del model
gc.collect()
torch.cuda.empty_cache()
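Both vLLM tests now finish with the same four-step cleanup (destroy_model_parallel, dropping the engine reference, gc.collect, torch.cuda.empty_cache) so the next engine has enough free GPU memory to start. A hedged sketch of hoisting that repeated block into a tearDown hook, as an illustrative refactor rather than what the commit does:

import gc
import unittest

import torch
from vllm.distributed.parallel_state import destroy_model_parallel


class VLLMTestCase(unittest.TestCase):
    """Illustrative base class: each test stores the loaded engine on self.model."""

    def tearDown(self):
        # Tear down vLLM's distributed process groups, drop the last reference
        # to the engine, then push Python and CUDA to return the memory before
        # the next test constructs a new engine.
        destroy_model_parallel()
        if hasattr(self, "model"):
            del self.model
        gc.collect()
        torch.cuda.empty_cache()
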
