6 changes: 4 additions & 2 deletions server/text_generation_server/layers/linear.py
@@ -2,8 +2,6 @@
 import torch
 from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 
 if SYSTEM == "rocm":
     try:
@@ -155,6 +153,8 @@ def get_linear(weight, bias, quantize):
             quant_type="nf4",
         )
     elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2Weight
+
         if not isinstance(weight, Exl2Weight):
             raise NotImplementedError(
                 f"The passed weight is not `exl2` compatible, loader needs to be updated."
@@ -165,6 +165,8 @@ def get_linear(weight, bias, quantize):
         linear = ExllamaQuantLinear(weight, bias)
 
     elif quantize == "gptq":
+        from text_generation_server.layers.gptq import GPTQWeight
+
         if not isinstance(weight, GPTQWeight):
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
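
Every hunk in this PR applies the same change: the `Exl2Weight` and `GPTQWeight` imports move from module level into the branch that actually needs them, so the quantization backends are only resolved when a matching `quantize` value is requested. A minimal standalone sketch of the deferred-import pattern (the `get_linear` here is a simplified stand-in, and `decimal` stands in for a heavy optional backend; neither reflects TGI's real logic):

import sys


def get_linear(quantize: str):
    # Simplified stand-in for TGI's get_linear: only the placement of the
    # import matters here, not the quantization logic itself.
    if quantize == "gptq":
        # Deferred import: resolved the first time this branch runs and
        # cached in sys.modules, so later calls pay only a dict lookup.
        import decimal  # noqa: F401

        return "gptq linear"
    return "plain linear"


# In a fresh interpreter, the stand-in backend is typically not loaded yet:
print("decimal" in sys.modules)  # usually False
get_linear("gptq")
print("decimal" in sys.modules)  # True: imported on first use
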
@@ -21,7 +21,6 @@
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "xpu":
@@ -198,6 +197,8 @@ def _load_gqa(config, prefix: str, weights):
     v_stop = v_offset + (rank + 1) * kv_block_size
 
     if config.quantize in ["gptq", "awq"]:
+        from text_generation_server.layers.gptq import GPTQWeight
+
         try:
             qweight_slice = weights._get_slice(f"{prefix}.qweight")
             q_qweight = qweight_slice[:, q_start:q_stop]
@@ -5,7 +5,6 @@
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -39,6 +38,8 @@ def load_multi_mqa(
 def _load_multi_mqa_gptq(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    from text_generation_server.layers.gptq import GPTQWeight
+
     if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
         world_size = weights.process_group.size()
         rank = weights.process_group.rank()
6 changes: 4 additions & 2 deletions server/text_generation_server/utils/weights.py
@@ -7,8 +7,6 @@
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_once
 
 
@@ -221,6 +219,8 @@ def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
 
     def get_weights_col(self, prefix: str, quantize: str):
         if quantize == "exl2":
+            from text_generation_server.layers.exl2 import Exl2Weight
+
             try:
                 q_weight = self.get_tensor(f"{prefix}.q_weight")
             except RuntimeError:
@@ -247,6 +247,8 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "exl2":
             raise ValueError("get_multi_weights_col is not supported for exl2")
         elif quantize in ["gptq", "awq"]:
+            from text_generation_server.layers.gptq import GPTQWeight
+
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
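
A quick way to sanity-check the net effect of these hunks (a sketch, assuming the server package is installed and that nothing else in the import chain pulls the backends in transitively): importing the weights module should no longer load the exl2/gptq layers as a side effect.

import sys

import text_generation_server.utils.weights  # noqa: F401

# With the module-level imports removed, neither backend should appear in
# sys.modules until a quantized weight is actually requested.
for mod in (
    "text_generation_server.layers.exl2",
    "text_generation_server.layers.gptq",
):
    assert mod not in sys.modules, f"{mod} was imported eagerly"
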