[Feature][Kernel] Support bitsandbytes quantization and QLoRA #4776

Merged: 13 commits, Jun 1, 2024
update per comments
chenqianfzh committed Jun 1, 2024
commit e16bcb69495540b21a3bd9423cdd5df8a78405ea
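For context, this commit is part of the PR that adds bitsandbytes in-flight quantization (and QLoRA adapter support) to vLLM. A minimal usage sketch of the feature, with a placeholder model name and keyword arguments written from memory of the PR rather than taken from this diff:

    from vllm import LLM, SamplingParams

    # Hypothetical usage sketch: quantize the model weights to 4-bit with
    # bitsandbytes at load time. The model name is a placeholder.
    llm = LLM(
        model="huggyllama/llama-7b",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
    )

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)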
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -35,7 +35,7 @@ def get_name(self) -> str:

    @classmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half]
+        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(self) -> int:
@@ -123,7 +123,7 @@ def create_weights(self, layer: torch.nn.Module,
            qweight,
            {
                "input_dim": 0,
-                # In bitsandbytes, a tensor of shape [n,m] is qunatized to
+                # In bitsandbytes, a tensor of shape [n,m] is quantized to
                #[n*m/pack_ratio, 1],so the output_dim is 0
                "output_dim": 0,
                "pack_factor": quant_ratio,
@@ -140,7 +140,7 @@ def apply(self,
        # only load the bitsandbytes module when needed
        from bitsandbytes import matmul_4bit

-        orginal_type = x.dtype
+        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)

        qweight = layer.qweight
@@ -167,7 +167,7 @@ def apply(self,

            current_index += output_size

-        out = out.to(orginal_type)
+        out = out.to(original_type)

        if bias is not None:
            out += bias
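The apply() path shown above casts activations to bfloat16, runs bitsandbytes' matmul_4bit against the packed weight, and casts the result back to the original activation dtype, while create_weights() records that an [n, m] tensor is packed into [n*m/pack_factor, 1]. A rough standalone sketch of that pattern, assuming the quantize_4bit/matmul_4bit APIs of a recent bitsandbytes release and a CUDA device (names here are illustrative, not copied from the vLLM layer):

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    # Pack a bf16 weight of shape [out_features, in_features] into 4-bit form.
    # With a pack factor of 2, the packed tensor has shape [out*in/2, 1].
    weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
    qweight, quant_state = F.quantize_4bit(weight, quant_type="nf4")
    print(qweight.shape)  # torch.Size([16384, 1]) == 256 * 128 / 2 rows

    def quantized_linear(x: torch.Tensor) -> torch.Tensor:
        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)    # mirror the bf16 cast in apply()
        out = bnb.matmul_4bit(bf_x, qweight.t(), quant_state)
        return out.to(original_type)   # cast back to the caller's dtype

    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    print(quantized_linear(x).shape)  # torch.Size([4, 256])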
3 changes: 1 addition & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -547,7 +547,7 @@ def save_model(


class BitsAndBytesModelLoader(BaseModelLoader):
-    """Model loader to load model weights with BitAndBytes quabntization."""
+    """Model loader to load model weights with BitAndBytes quantization."""

    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
@@ -654,7 +654,6 @@ def _get_quantized_weights_iterator(
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
-
        as well as the quantization state dictionary."""

        # only load the bitsandbytes module when needed
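For intuition about what _get_quantized_weights_iterator returns, here is a heavily simplified, hypothetical sketch (not the vLLM implementation; the helper names and the in-memory weight dict are made up) of an iterator that quantizes target modules on the fly and collects their bitsandbytes quantization states:

    from typing import Any, Dict, Generator, Tuple

    import torch
    import bitsandbytes.functional as F

    def quantized_weights_iterator(
        raw_weights: Dict[str, torch.Tensor],
        target_modules: Tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj"),
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]:
        """Yield (name, tensor) pairs, quantizing target modules to 4-bit on the
        fly, and expose the per-weight quantization states in a dictionary."""
        quant_states: Dict[str, Any] = {}

        def generator() -> Generator[Tuple[str, torch.Tensor], None, None]:
            for name, weight in raw_weights.items():
                if any(module in name for module in target_modules):
                    qweight, state = F.quantize_4bit(weight.cuda(), quant_type="nf4")
                    quant_states[name] = state
                    yield name, qweight
                else:
                    yield name, weight

        # Note: quant_states fills in as the generator is consumed.
        return generator(), quant_states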