Skip to content

Commit

Permalink
[Core] Set linear_weights directly on the layer (vllm-project#3977)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored and joerunde committed Apr 18, 2024
1 parent 1b69da4 commit ae3bae0
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 102 deletions.
2 changes: 1 addition & 1 deletion csrc/quantization/gptq/q_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2067,7 +2067,7 @@ void gptq_shuffle
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 32 / bit,
q_weight.size(1),
bit
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).cuda()

# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
Expand Down
12 changes: 4 additions & 8 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def set_mapping(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
Expand Down Expand Up @@ -402,10 +402,6 @@ def forward(self, input_):
if self.base_layer.skip_bias_add else None)
return output, output_bias

@property
def linear_weights(self):
return self.base_layer.linear_weights

@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
Expand Down Expand Up @@ -505,7 +501,7 @@ def set_lora(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
Expand Down Expand Up @@ -746,7 +742,7 @@ def set_lora(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
Expand Down Expand Up @@ -838,7 +834,7 @@ def set_mapping(

def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x)
self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
Expand Down
77 changes: 40 additions & 37 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import List, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""

@abstractmethod
def create_weights(self, input_size_per_partition: int,
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Create weights for a linear layer."""
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError

@abstractmethod
def apply_weights(self,
weights: Dict[str, torch.Tensor],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights to the input tensor."""
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError


Expand All @@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add

def create_weights(self, input_size_per_partition: int,
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
return {"weight": weight}
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def apply_weights(self,
weights: Dict[str, torch.Tensor],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = weights["weight"]
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
Expand Down Expand Up @@ -111,12 +118,9 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
self.linear_method.create_weights(self, self.input_size,
self.output_size, self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
Expand All @@ -126,7 +130,7 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
output = self.linear_method.apply_weights(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias

Expand Down Expand Up @@ -177,13 +181,13 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size_per_partition, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
self.linear_method.create_weights(self,
self.input_size,
self.output_size_per_partition,
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
Expand Down Expand Up @@ -211,8 +215,7 @@ def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_, bias)
output_parallel = self.linear_method.apply_weights(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
Expand Down Expand Up @@ -523,13 +526,13 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size_per_partition, self.output_size, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
self.linear_method.create_weights(self,
self.input_size_per_partition,
self.output_size,
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)

if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
Expand Down Expand Up @@ -569,7 +572,7 @@ def forward(self, input_):

# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_parallel)
self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
Expand Down
29 changes: 16 additions & 13 deletions vllm/model_executor/layers/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config

def create_weights(self, input_size_per_partition: int,
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
Expand Down Expand Up @@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int,
"input_dim": 0,
"output_dim": 1,
})
return {
"qweight": qweight,
"qzeros": qzeros,
"scales": scales,
}

layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)

def apply_weights(self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
scales = weights["scales"]
qzeros = weights["qzeros"]
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
Expand All @@ -163,5 +166,5 @@ def apply_weights(self,
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out = out + bias
out.add_(bias)
return out.reshape(out_shape)
47 changes: 26 additions & 21 deletions vllm/model_executor/layers/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
**extra_weight_attrs,
):
del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
Expand Down Expand Up @@ -179,37 +181,40 @@ def create_weights(
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
})
return {
"qweight": qweight,
"g_idx": g_idx,
"qzeros": qzeros,
"scales": scales,
"exllama_state": exllama_state,
}

layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("g_idx", g_idx)
set_weight_attrs(g_idx, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)

layer.exllama_state = exllama_state

def apply_weights(self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
torch.int)
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
layer.g_idx.data = torch.empty((0, ),
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"],
weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY,
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None:
output = output + bias
output.add_(bias)
return output.reshape(out_shape)
23 changes: 13 additions & 10 deletions vllm/model_executor/layers/quantization/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig):

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
**extra_weight_attrs,
):
del output_size # Unused.

if params_dtype != torch.float16:
Expand Down Expand Up @@ -187,21 +189,22 @@ def create_weights(
dtype=torch.int),
requires_grad=False)

return {
"B": qweight,
"s": scales,
"workspace": workspace,
}
layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)

def apply_weights(
self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = weights["B"]
scales = weights["s"]
workspace = weights["workspace"]
qweight = layer.B
scales = layer.s
workspace = layer.workspace

x_2d = x.view(-1, x.shape[-1])

Expand Down
Loading

0 comments on commit ae3bae0

Please sign in to comment.