Skip to content

[RFC][Refactor] Generalize linear_method to be quant_method #4342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.linear_method, Fp8LinearMethod)
assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
4 changes: 2 additions & 2 deletions tests/tensorizer_loader/test_tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock()

result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method)
quant_method=mock_linear_method)

mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method)
quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value

Expand Down
30 changes: 13 additions & 17 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,9 @@ def set_mapping(
self.indices = base_indices
self.indices_len = indices_len

def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
Expand All @@ -411,7 +410,7 @@ def forward(self, input_):
if not self.base_layer.skip_bias_add else None)

# Matrix multiply.
output_parallel = self.apply_weights(input_, bias)
output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
Expand Down Expand Up @@ -517,10 +516,9 @@ def set_lora(
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)

def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
Expand Down Expand Up @@ -758,10 +756,9 @@ def set_lora(
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)

def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
Expand Down Expand Up @@ -854,9 +851,8 @@ def set_mapping(
self.indices = base_indices
self.indices_len = indices_len

def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
Expand Down Expand Up @@ -889,7 +885,7 @@ def forward(self, input_):
input_parallel = splitted_input[tp_rank].contiguous()

# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name)
from vllm.model_executor.layers.fused_moe.quant_methods import (
MoEMethodBase, UnquantizedMoEMethod)

__all__ = [
"fused_moe",
"get_config_file_name",
"MoEMethodBase",
"UnquantizedMoEMethod",
]
72 changes: 72 additions & 0 deletions vllm/model_executor/layers/fused_moe/quant_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from abc import abstractmethod

import torch
from torch import nn

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs


class MoEMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) MoE methods."""

@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_total_experts: int,
intermediate_size: int, hidden_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.

Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError

@abstractmethod
def apply(self, layer: torch.nn.Module, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.

Expects create_weights to have been called before on the layer."""
raise NotImplementedError


class UnquantizedMoEMethod(MoEMethodBase):
"""MoE method without quantization."""

def create_weights(self, layer: torch.nn.Module, num_total_experts: int,
intermediate_size: int, hidden_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
ws = nn.Parameter(
torch.empty(num_total_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype))
w2s = nn.Parameter(
torch.empty(num_total_experts,
hidden_size,
intermediate_size,
dtype=params_dtype))
layer.register_parameter("ws", ws)
layer.register_parameter("w2s", w2s)
set_weight_attrs(ws, extra_weight_attrs)
set_weight_attrs(w2s, extra_weight_attrs)

def apply(self, layer: torch.nn.Module, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> torch.Tensor:
return fused_moe(hidden_states,
layer.ws,
layer.w2s,
router_logits,
layer.top_k,
renormalize=True,
inplace=True)
Loading