Commit 1d859cc

[Core] Set linear_weights directly on the layer (vllm-project#3977)

Authored by Yard1, committed by andy-neuma
1 parent 33a59a3

8 files changed: +114, -102 lines

csrc/quantization/gptq/q_gemm.cu (1 addition, 1 deletion)

```diff
@@ -2067,7 +2067,7 @@ void gptq_shuffle
   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
   vllm::gptq::shuffle_exllama_weight(
       (uint32_t*) q_weight.data_ptr(),
-      q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
+      q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
       q_weight.size(0) * 32 / bit,
       q_weight.size(1),
       bit
```
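This kernel-side change pairs with the gptq.py change further down: the "no permutation" sentinel for `q_perm` is now an empty tensor on the weight's device rather than a meta tensor, so the shuffle entry point must accept either. A minimal sketch of the two sentinel conventions in plain PyTorch (names here are illustrative, not from the kernel):

```python
import torch

# Before this commit, "no permutation" was signalled with a meta tensor;
# after it, gptq.py passes an empty tensor instead, hence the extra
# numel() check in the C++ guard above.
meta_sentinel = torch.empty((1, 1), device="meta")
empty_sentinel = torch.empty((0, ))


def has_no_perm(q_perm: torch.Tensor) -> bool:
    # Mirrors: q_perm.device().is_meta() || q_perm.numel() == 0
    return q_perm.is_meta or q_perm.numel() == 0


assert has_no_perm(meta_sentinel) and has_no_perm(empty_sentinel)
```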

tests/kernels/test_moe.py (1 addition, 1 deletion)

```diff
@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
     ).cuda()
 
     # Load the weights
-    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
+    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
```

vllm/lora/layers.py (4 additions, 8 deletions)

```diff
@@ -368,7 +368,7 @@ def set_mapping(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora(
             x,
             self.lora_a_stacked,
@@ -402,10 +402,6 @@ def forward(self, input_):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
 
-    @property
-    def linear_weights(self):
-        return self.base_layer.linear_weights
-
     @classmethod
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
@@ -505,7 +501,7 @@ def set_lora(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -746,7 +742,7 @@ def set_lora(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -838,7 +834,7 @@ def set_mapping(
 
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x)
+            self.base_layer, x)
         _apply_lora(
             x,
             self.lora_a_stacked,
```
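Because `create_weights` now registers its tensors directly on the base layer, the LoRA wrappers simply hand the wrapped module to `apply_weights`, and the `linear_weights` forwarding property becomes dead code. A minimal sketch (with hypothetical names) of why plain attribute access through the wrapper suffices:

```python
import torch
import torch.nn as nn

# Once parameters live on the wrapped module itself, the wrapper reaches
# them via ordinary attribute chaining; no forwarding property is needed.
base = nn.Module()
base.register_parameter(
    "weight", nn.Parameter(torch.empty(8, 8), requires_grad=False))


class Wrapper(nn.Module):
    def __init__(self, base_layer: nn.Module):
        super().__init__()
        self.base_layer = base_layer


w = Wrapper(base)
print(w.base_layer.weight.shape)  # torch.Size([8, 8])
```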

vllm/model_executor/layers/linear.py (40 additions, 37 deletions)

```diff
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -30,19 +30,24 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+
+        The weights will be set as attributes of the layer."""
         raise NotImplementedError
 
     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
 
@@ -57,22 +62,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
@@ -113,12 +120,9 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          self.output_size, self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -128,7 +132,7 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
@@ -179,13 +183,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          self.output_size_per_partition,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -217,8 +221,7 @@ def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -554,13 +557,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          self.output_size,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -604,7 +607,7 @@ def forward(self, input_):
 
         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
```
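This interface change is the heart of the commit: a linear method no longer returns a dict of tensors for the layer to register; it registers its parameters on the layer it is given and reads them back at apply time. A minimal sketch of a custom method written against the new signatures; the class name and the scaling twist are hypothetical, and `set_weight_attrs` is assumed to come from `vllm.model_executor.utils` as in the existing methods:

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import Parameter

from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.utils import set_weight_attrs


class ScaledLinearMethod(LinearMethodBase):
    """Hypothetical example method following the new layer-owned contract."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        # New contract: register on the layer instead of returning a dict.
        layer.register_parameter("weight", weight)
        # Forward extra attrs (e.g. weight_loader) onto the parameter.
        set_weight_attrs(weight, extra_weight_attrs)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # New contract: read the weights back off the layer.
        return F.linear(x, layer.weight * self.scale, bias)
```

One consequence of this design is visible in the `__init__` diffs above: the parallel linear layers no longer loop over a returned dict to register parameters; they just pass themselves plus `weight_loader=self.weight_loader`, which reaches each parameter via `extra_weight_attrs`.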

vllm/model_executor/layers/quantization/awq.py (16 additions, 13 deletions)

```diff
@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int,
             "input_dim": 0,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "qzeros": qzeros,
-            "scales": scales,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        scales = weights["scales"]
-        qzeros = weights["qzeros"]
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
@@ -163,5 +166,5 @@ def apply_weights(self,
         out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                            pack_factor)
         if bias is not None:
-            out = out + bias
+            out.add_(bias)
         return out.reshape(out_shape)
```
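Besides adopting the new interface, the bias addition switches from `out = out + bias` to the in-place `out.add_(bias)`. The values are identical; the in-place form avoids allocating a second output tensor, which is safe here because `out` is freshly produced by the GEMM and not aliased elsewhere. A quick self-contained check of the equivalence:

```python
import torch

out = torch.randn(2, 8)
bias = torch.randn(8)

expected = out + bias  # old form: allocates a new tensor
out.add_(bias)         # new form: writes into the existing buffer

assert torch.equal(out, expected)
```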

vllm/model_executor/layers/quantization/gptq.py (26 additions, 21 deletions)

```diff
@@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):
 
     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
@@ -179,37 +181,40 @@ def create_weights(
             "input_dim": scale_and_zero_input_dim,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "g_idx": g_idx,
-            "qzeros": qzeros,
-            "scales": scales,
-            "exllama_state": exllama_state,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("g_idx", g_idx)
+        set_weight_attrs(g_idx, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+
+        layer.exllama_state = exllama_state
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
+        qweight = layer.qweight
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+        if layer.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
-                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
-                    torch.int)
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
-                weights["g_idx"] = torch.empty((1, 1), device="meta")
-            weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                layer.g_idx.data = torch.empty((0, ),
+                                               device=layer.g_idx.device)
+            layer.exllama_state = ExllamaState.READY
+            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
-        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
-                               weights["qzeros"], weights["scales"],
-                               weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY,
+        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
+                               layer.scales, layer.g_idx,
+                               layer.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
-            output = output + bias
+            output.add_(bias)
         return output.reshape(out_shape)
```
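One subtlety in this rewrite: `g_idx` is now a registered `nn.Parameter`, and `nn.Module` refuses to rebind a parameter attribute to a plain tensor, which is why the lazy-shuffle path writes through `.data` instead of reassigning the attribute (and why the sentinel became an empty tensor on the parameter's device rather than a meta tensor). A small sketch of that constraint, with illustrative names:

```python
import torch
import torch.nn as nn

layer = nn.Module()
layer.register_parameter(
    "g_idx",
    nn.Parameter(torch.randperm(8).to(torch.int), requires_grad=False))

try:
    # Rebinding a registered parameter to a plain tensor is rejected.
    layer.g_idx = torch.argsort(layer.g_idx).to(torch.int)
except TypeError as err:
    print(f"rejected: {err}")

# Writing through .data swaps the storage while keeping the Parameter
# wrapper, which is what the diff does on the first forward pass.
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
```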

vllm/model_executor/layers/quantization/marlin.py (13 additions, 10 deletions)

```diff
@@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig):
 
     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
 
         if params_dtype != torch.float16:
@@ -187,21 +189,22 @@ def create_weights(
                                       dtype=torch.int),
                               requires_grad=False)
 
-        return {
-            "B": qweight,
-            "s": scales,
-            "workspace": workspace,
-        }
+        layer.register_parameter("B", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("s", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+        layer.register_parameter("workspace", workspace)
+        set_weight_attrs(workspace, extra_weight_attrs)
 
     def apply_weights(
         self,
-        weights: Dict[str, Any],
+        layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        qweight = weights["B"]
-        scales = weights["s"]
-        workspace = weights["workspace"]
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace
 
         x_2d = x.view(-1, x.shape[-1])
```