[Core] Set linear_weights directly on the layer #3977

Merged
9 commits merged on Apr 11, 2024
Changes from 7 commits
2 changes: 1 addition & 1 deletion csrc/quantization/gptq/q_gemm.cu
@@ -2067,7 +2067,7 @@ void gptq_shuffle
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 32 / bit,
q_weight.size(1),
bit
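This kernel-side change pairs with the gptq.py change later in this diff: a layer with no activation reordering now keeps an empty g_idx tensor on its real device instead of swapping in a meta tensor, so the shuffle has to accept both sentinels. A minimal sketch of the equivalent check in plain PyTorch (illustration only, not the CUDA code above):

```python
import torch

# Old sentinel: a tensor on the "meta" device. New sentinel: an empty tensor
# kept on the layer's real device (see the gptq.py hunks further down).
no_perm_old = torch.empty((1, 1), device="meta")
no_perm_new = torch.empty((0, ), dtype=torch.int)

def uses_permutation(q_perm: torch.Tensor) -> bool:
    # Mirrors the updated condition: meta device OR zero elements -> pass NULL.
    return not (q_perm.device.type == "meta" or q_perm.numel() == 0)

assert not uses_permutation(no_perm_old)
assert not uses_permutation(no_perm_new)
```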
2 changes: 1 addition & 1 deletion tests/kernels/test_moe.py
@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).cuda()

# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
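The updated test reads the gate weight straight off the layer, since this PR registers the tensors as module parameters instead of keeping them in a linear_weights dict. A toy illustration of the same access pattern, with a plain nn.Linear standing in for the vLLM layer:

```python
import torch
import torch.nn as nn

# Toy stand-in for the vLLM gate layer after this PR: the weight is an
# ordinary registered parameter, so it can be filled in place directly.
gate = nn.Linear(8, 4, bias=False)
src = torch.randn(4, 8)
gate.weight.data[:] = src
assert torch.equal(gate.weight.data, src)
```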
12 changes: 4 additions & 8 deletions vllm/lora/layers.py
@@ -368,7 +368,7 @@ def set_mapping(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
@@ -402,10 +402,6 @@ def forward(self, input_):
if self.base_layer.skip_bias_add else None)
return output, output_bias

@property
def linear_weights(self):
return self.base_layer.linear_weights

@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
@@ -505,7 +501,7 @@ def set_lora(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -746,7 +742,7 @@ def set_lora(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -838,7 +834,7 @@ def set_mapping(

def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x)
self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
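All of the LoRA wrappers now hand the base layer itself to apply_weights, and the linear_weights passthrough property is removed because the tensors live on the base layer as registered parameters. A minimal sketch of the delegation pattern under the new signature (a hypothetical wrapper, not one of the classes above):

```python
from typing import Optional

import torch
import torch.nn as nn


class LoRAWrappedLinear(nn.Module):
    """Hypothetical wrapper following the new apply_weights(layer, x, bias) contract."""

    def __init__(self, base_layer: nn.Module):
        super().__init__()
        # weight/qweight/scales/... are registered on base_layer by create_weights.
        self.base_layer = base_layer

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The linear method looks its tensors up on the layer it is handed,
        # so the wrapper simply forwards the base layer itself.
        return self.base_layer.linear_method.apply_weights(self.base_layer, x, bias)
```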
77 changes: 40 additions & 37 deletions vllm/model_executor/layers/linear.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import List, Optional

import torch
import torch.nn.functional as F
@@ -29,19 +29,24 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""

@abstractmethod
def create_weights(self, input_size_per_partition: int,
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Create weights for a linear layer."""
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
robertgshaw2-redhat (Collaborator) commented on Apr 10, 2024:
Could this just be an optional weight_loader? I see the value of enabling kwargs here for future extensibility, but I don't see a case other than weight_loader that exists yet, so perhaps making the argument explicit is better until we have a reason to allow kwargs.

Author:
Actually, making it a kwarg is easier, since we don't need extra handling to avoid setting the weight_loader when it is left unspecified.

Collaborator:
:)

"""Create weights for a linear layer.

The weights will be set as attributes of the layer."""
raise NotImplementedError

@abstractmethod
def apply_weights(self,
weights: Dict[str, torch.Tensor],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights to the input tensor."""
"""Apply the weights in layer to the input tensor.

Expects create_weights to have been called before on the layer."""
raise NotImplementedError
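Put together, the new contract is: create_weights registers its tensors on the layer (forwarding **extra_weight_attrs, which today carries only weight_loader, via set_weight_attrs), and apply_weights reads those tensors back off the layer. A self-contained toy sketch of both sides of the interface, with simplified signatures and a stubbed set_weight_attrs (not real vLLM classes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


def set_weight_attrs(weight: torch.Tensor, attrs: dict):
    # Stub of vLLM's helper: attach extra attributes such as weight_loader.
    for key, value in (attrs or {}).items():
        setattr(weight, key, value)


class ToyLinearMethod:
    """Toy method obeying the new contract (not a real vLLM class)."""

    def create_weights(self, layer: nn.Module, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(output_size, input_size,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)    # lives on the layer now
        set_weight_attrs(weight, extra_weight_attrs)  # e.g. weight.weight_loader

    def apply_weights(self, layer: nn.Module, x: torch.Tensor,
                      bias=None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)        # read back off the layer


class ToyLinear(nn.Module):
    """Toy layer showing the caller side of the new interface."""

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.linear_method = ToyLinearMethod()
        self.linear_method.create_weights(self, input_size, output_size,
                                          torch.float32,
                                          weight_loader=self.weight_loader)

    def weight_loader(self, param: Parameter, loaded: torch.Tensor):
        param.data.copy_(loaded)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_method.apply_weights(self, x)
```

The real layers below (ColumnParallelLinear, RowParallelLinear, and the quantized methods) follow the same wiring, only with per-partition sizes and quantization-specific parameters.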


@@ -56,22 +61,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add

def create_weights(self, input_size_per_partition: int,
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
return {"weight": weight}
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def apply_weights(self,
weights: Dict[str, torch.Tensor],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = weights["weight"]
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
@@ -112,12 +119,9 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
self.linear_method.create_weights(self, self.input_size,
self.output_size, self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
@@ -127,7 +131,7 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
output = self.linear_method.apply_weights(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias

@@ -178,13 +182,13 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size, self.output_size_per_partition, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
self.linear_method.create_weights(self,
self.input_size,
self.output_size_per_partition,
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@@ -212,8 +216,7 @@ def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_, bias)
output_parallel = self.linear_method.apply_weights(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -524,13 +527,13 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_weights = self.linear_method.create_weights(
self.input_size_per_partition, self.output_size, self.input_size,
self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
if isinstance(weight, torch.Tensor):
self.register_parameter(name, weight)
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
self.linear_method.create_weights(self,
self.input_size_per_partition,
self.output_size,
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)

if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@@ -570,7 +573,7 @@ def forward(self, input_):

# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self.linear_weights, input_parallel)
self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
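The weight_loader passed through create_weights ends up attached to each parameter, where the model loader can later look it up. A toy sketch of that consumption side; the lookup convention here is an assumption for illustration, not code from this diff:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4, bias=False)

def row_major_loader(param: nn.Parameter, loaded: torch.Tensor):
    # A real loader would shard along the tensor-parallel dimension first.
    param.data.copy_(loaded)

# What set_weight_attrs(weight, {"weight_loader": ...}) effectively does:
layer.weight.weight_loader = row_major_loader

checkpoint = {"weight": torch.randn(4, 4)}
for name, param in layer.named_parameters():
    loader = getattr(param, "weight_loader",
                     lambda p, t: p.data.copy_(t))  # default: plain copy
    loader(param, checkpoint[name])
```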
29 changes: 16 additions & 13 deletions vllm/model_executor/layers/quantization/awq.py
@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config

def create_weights(self, input_size_per_partition: int,
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
@@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int,
"input_dim": 0,
"output_dim": 1,
})
return {
"qweight": qweight,
"qzeros": qzeros,
"scales": scales,
}

layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)

def apply_weights(self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
scales = weights["scales"]
qzeros = weights["qzeros"]
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
@@ -163,5 +166,5 @@ def apply_weights(self,
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out = out + bias
out.add_(bias)
return out.reshape(out_shape)
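The bias is now added in place with add_ instead of allocating a new tensor; this is safe here because out is the freshly allocated GEMM result rather than a view of the input. A tiny sketch of the difference:

```python
import torch

out = torch.randn(2, 8)   # stands in for the freshly allocated GEMM output
bias = torch.randn(8)
expected = out + bias     # out-of-place: allocates a second tensor
out.add_(bias)            # in-place: reuses the GEMM output buffer
assert torch.equal(out, expected)
```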
47 changes: 26 additions & 21 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
**extra_weight_attrs,
):
del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
@@ -179,37 +181,40 @@ def create_weights(
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
})
return {
"qweight": qweight,
"g_idx": g_idx,
"qzeros": qzeros,
"scales": scales,
"exllama_state": exllama_state,
}

layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("g_idx", g_idx)
set_weight_attrs(g_idx, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)

layer.exllama_state = exllama_state
Collaborator:
The LinearLayer does not need to know this info, so I think it should be GPTQLinearMethod.exllama_state ... it probably should not have been in the weights dict before. This will avoid having this dangling member.

Author:
Ah, actually I don't think this is possible, since GPTQLinearMethod is a singleton across all layers.

Collaborator:
Sounds good.


def apply_weights(self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
torch.int)
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
layer.g_idx.data = torch.empty((0, ),
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"],
weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY,
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None:
output = output + bias
output.add_(bias)
return output.reshape(out_shape)
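As the review thread above notes, GPTQLinearMethod is a single object shared by every linear layer, so the exllama shuffle flag has to live on each layer rather than on the method. A self-contained toy showing why per-layer state is needed when the method object is shared (hypothetical classes, not vLLM code):

```python
from enum import Enum

import torch.nn as nn


class ExllamaState(Enum):
    UNINITIALIZED = 0
    READY = 1


class SharedMethod:
    # One instance of this is reused for every linear layer in the model.
    def create_weights(self, layer: nn.Module):
        layer.exllama_state = ExllamaState.UNINITIALIZED  # per-layer flag

    def apply_weights(self, layer: nn.Module):
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            # the one-time shuffle of this layer's weights would happen here
            layer.exllama_state = ExllamaState.READY


method = SharedMethod()
layers = [nn.Linear(4, 4) for _ in range(3)]
for layer in layers:
    method.create_weights(layer)

method.apply_weights(layers[0])  # only the first layer has run its shuffle
assert layers[0].exllama_state is ExllamaState.READY
assert layers[1].exllama_state is ExllamaState.UNINITIALIZED
```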
23 changes: 13 additions & 10 deletions vllm/model_executor/layers/quantization/marlin.py
@@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig):

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
**extra_weight_attrs,
):
del output_size # Unused.

if params_dtype != torch.float16:
@@ -187,21 +189,22 @@ def create_weights(
dtype=torch.int),
requires_grad=False)

return {
"B": qweight,
"s": scales,
"workspace": workspace,
}
layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)

def apply_weights(
self,
weights: Dict[str, Any],
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = weights["B"]
scales = weights["s"]
workspace = weights["workspace"]
qweight = layer.B
scales = layer.s
workspace = layer.workspace

x_2d = x.view(-1, x.shape[-1])
