diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index 655158e38f557..cc56649917a8a 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -2067,7 +2067,7 @@ void gptq_shuffle
     const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
     vllm::gptq::shuffle_exllama_weight(
         (uint32_t*) q_weight.data_ptr(),
-        q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
+        q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
         q_weight.size(0) * 32 / bit,
         q_weight.size(1),
         bit
diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index affbbfb4aa94e..046f11d957bdd 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
     ).cuda()
 
     # Load the weights
-    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
+    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 84a94091486d7..a8ec4dcfd6137 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -368,7 +368,7 @@ def set_mapping(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora(
             x,
             self.lora_a_stacked,
@@ -402,10 +402,6 @@ def forward(self, input_):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
 
-    @property
-    def linear_weights(self):
-        return self.base_layer.linear_weights
-
     @classmethod
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
@@ -505,7 +501,7 @@ def set_lora(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -746,7 +742,7 @@ def set_lora(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -838,7 +834,7 @@ def set_mapping(
 
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x)
+            self.base_layer, x)
         _apply_lora(
             x,
             self.lora_a_stacked,
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 8f42b3e8a4abe..3ca870742efc5 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+
+        The weights will be set as attributes of the layer."""
         raise NotImplementedError
 
     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
 
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
@@ -111,12 +118,9 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          self.output_size, self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -126,7 +130,7 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
@@ -177,13 +181,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          self.output_size_per_partition,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -211,8 +215,7 @@ def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,13 +526,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          self.output_size,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -569,7 +572,7 @@ def forward(self, input_):
 
         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index daea5ac73e429..98651aed8be0e 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int,
             "input_dim": 0,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "qzeros": qzeros,
-            "scales": scales,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        scales = weights["scales"]
-        qzeros = weights["qzeros"]
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
@@ -163,5 +166,5 @@ def apply_weights(self,
             out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                                pack_factor)
         if bias is not None:
-            out = out + bias
+            out.add_(bias)
         return out.reshape(out_shape)
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 757ab1af8392e..f370b94a210ee 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):
 
     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
@@ -179,37 +181,40 @@ def create_weights(
                 "input_dim": scale_and_zero_input_dim,
                 "output_dim": 1,
             })
-        return {
-            "qweight": qweight,
-            "g_idx": g_idx,
-            "qzeros": qzeros,
-            "scales": scales,
-            "exllama_state": exllama_state,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("g_idx", g_idx)
+        set_weight_attrs(g_idx, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+
+        layer.exllama_state = exllama_state
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
+        qweight = layer.qweight
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+        if layer.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
-                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
-                    torch.int)
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
-                weights["g_idx"] = torch.empty((1, 1), device="meta")
-            weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                layer.g_idx.data = torch.empty((0, ),
+                                               device=layer.g_idx.device)
+            layer.exllama_state = ExllamaState.READY
+            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
-        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
-                               weights["qzeros"], weights["scales"],
-                               weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY,
+        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
+                               layer.scales, layer.g_idx,
+                               layer.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
-            output = output + bias
+            output.add_(bias)
         return output.reshape(out_shape)
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index a6482c059cc41..bf0500f1155a1 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig):
 
     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
 
         if params_dtype != torch.float16:
@@ -187,21 +189,22 @@ def create_weights(
                                           dtype=torch.int),
                               requires_grad=False)
 
-        return {
-            "B": qweight,
-            "s": scales,
-            "workspace": workspace,
-        }
+        layer.register_parameter("B", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("s", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+        layer.register_parameter("workspace", workspace)
+        set_weight_attrs(workspace, extra_weight_attrs)
 
     def apply_weights(
         self,
-        weights: Dict[str, Any],
+        layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        qweight = weights["B"]
-        scales = weights["s"]
-        workspace = weights["workspace"]
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace
 
         x_2d = x.view(-1, x.shape[-1])
 
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index bb295df2acc3f..661ff9c55d0d1 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -103,17 +104,18 @@ def create_weights(self, input_size_per_partition: int,
         set_weight_attrs(lookup_table, {
             "output_dim": 0,
         })
-        return {
-            "qweight": qweight,
-            "lookup_table": lookup_table,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("lookup_table", lookup_table)
+        set_weight_attrs(lookup_table, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        lookup_table = weights["lookup_table"]
+        qweight = layer.qweight
+        lookup_table = layer.lookup_table
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         if is_hip():
@@ -126,5 +128,5 @@ def apply_weights(self,
             ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
 
         if bias is not None:
-            out = out + bias
+            out.add_(bias)
         return out.reshape(out_shape)
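For context, a minimal sketch (illustrative only, not applied by this patch) of a linear method written against the refactored interface: create_weights registers its parameters directly on the layer it is given and forwards **extra_weight_attrs (for example the weight_loader passed in by ColumnParallelLinear and RowParallelLinear) via set_weight_attrs, while apply_weights reads the parameters back as attributes of the layer instead of from a weights dict. The class name NoopLinearMethod is hypothetical, and the import path for set_weight_attrs (vllm.model_executor.utils) is assumed from the surrounding code rather than shown in this diff.

    # Illustrative sketch of the post-refactor LinearMethodBase contract.
    from typing import Optional

    import torch
    import torch.nn.functional as F
    from torch.nn import Parameter

    from vllm.model_executor.layers.linear import LinearMethodBase
    from vllm.model_executor.utils import set_weight_attrs  # assumed path


    class NoopLinearMethod(LinearMethodBase):  # hypothetical example class
        """Unquantized toy method following the layer-attribute contract."""

        def create_weights(self, layer: torch.nn.Module,
                           input_size_per_partition: int,
                           output_size_per_partition: int, input_size: int,
                           output_size: int, params_dtype: torch.dtype,
                           **extra_weight_attrs):
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            # Sharding metadata consumed by the weight loaders.
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            # New contract: parameters are registered on the layer itself
            # rather than returned in a dict.
            layer.register_parameter("weight", weight)
            # Propagate e.g. the weight_loader injected by the parallel layers.
            set_weight_attrs(weight, extra_weight_attrs)

        def apply_weights(self,
                          layer: torch.nn.Module,
                          x: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            # New contract: weights are read back off the layer.
            return F.linear(x, layer.weight, bias)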