|
| 1 | +from abc import abstractmethod |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from vllm.distributed import (get_tensor_model_parallel_rank, |
| 7 | + get_tensor_model_parallel_world_size, |
| 8 | + tensor_model_parallel_all_reduce) |
| 9 | +from vllm.logger import init_logger |
| 10 | +from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe |
| 11 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 12 | + QuantizationConfig, QuantizeMethodBase) |
| 13 | +from vllm.model_executor.utils import set_weight_attrs |
| 14 | + |
| 15 | +logger = init_logger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +class FusedMoEMethodBase(QuantizeMethodBase): |
| 19 | + |
| 20 | + @abstractmethod |
| 21 | + def create_weights(self, layer: torch.nn.Module, num_experts: int, |
| 22 | + hidden_size: int, intermediate_size: int, |
| 23 | + params_dtype: torch.dtype, **extra_weight_attrs): |
| 24 | + raise NotImplementedError |
| 25 | + |
| 26 | + @abstractmethod |
| 27 | + def apply(self, |
| 28 | + layer: torch.nn.Module, |
| 29 | + x: torch.Tensor, |
| 30 | + router_logits: torch.Tensor, |
| 31 | + top_k: int, |
| 32 | + renormalize: bool = True) -> torch.Tensor: |
| 33 | + raise NotImplementedError |
| 34 | + |
| 35 | + |
| 36 | +class UnquantizedFusedMoEMethod(FusedMoEMethodBase): |
| 37 | + """MoE method without quantization.""" |
| 38 | + |
| 39 | + def create_weights(self, layer: torch.nn.Module, num_experts: int, |
| 40 | + hidden_size: int, intermediate_size: int, |
| 41 | + params_dtype: torch.dtype, **extra_weight_attrs): |
| 42 | + |
| 43 | + # Fused gate_up_proj (column parallel) |
| 44 | + w13_weight = torch.nn.Parameter(torch.empty(num_experts, |
| 45 | + 2 * intermediate_size, |
| 46 | + hidden_size, |
| 47 | + dtype=params_dtype), |
| 48 | + requires_grad=False) |
| 49 | + layer.register_parameter("w13_weight", w13_weight) |
| 50 | + set_weight_attrs(w13_weight, extra_weight_attrs) |
| 51 | + |
| 52 | + # down_proj (row parallel) |
| 53 | + w2_weight = torch.nn.Parameter(torch.empty(num_experts, |
| 54 | + hidden_size, |
| 55 | + intermediate_size, |
| 56 | + dtype=params_dtype), |
| 57 | + requires_grad=False) |
| 58 | + layer.register_parameter("w2_weight", w2_weight) |
| 59 | + set_weight_attrs(w2_weight, extra_weight_attrs) |
| 60 | + |
| 61 | + def apply(self, |
| 62 | + layer: torch.nn.Module, |
| 63 | + x: torch.Tensor, |
| 64 | + router_logits: torch.Tensor, |
| 65 | + top_k: int, |
| 66 | + renormalize: bool = True) -> torch.Tensor: |
| 67 | + |
| 68 | + return fused_moe(x, |
| 69 | + layer.w13_weight, |
| 70 | + layer.w2_weight, |
| 71 | + router_logits, |
| 72 | + top_k, |
| 73 | + renormalize=renormalize, |
| 74 | + inplace=True) |
| 75 | + |
| 76 | + |
| 77 | +class FusedMoE(torch.nn.Module): |
| 78 | + """FusedMoE layer for MoE models. |
| 79 | +
|
| 80 | + This layer contains both MergedColumnParallel weights (gate_up_proj / |
| 81 | + w13) and RowParallelLinear weights (down_proj/ w2). |
| 82 | +
|
| 83 | + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We |
| 84 | + copy that naming convention here and handle any remapping in the |
| 85 | + load_weights function in each model implementation. |
| 86 | +
|
| 87 | + Args: |
| 88 | + num_experts: Number of experts in the model |
| 89 | + top_k: Number of experts selected for each token |
| 90 | + hidden_size: Input hidden state size of the transformer |
| 91 | + intermediate_size: Intermediate size of the experts |
| 92 | + params_dtype: Data type for the parameters. |
| 93 | + reduce_results: Whether to all all_reduce on the output of the layer |
| 94 | + renomalize: Whether to renormalize the logits in the fused_moe kernel |
| 95 | + quant_config: Quantization configure. |
| 96 | + """ |
| 97 | + |
| 98 | + def __init__( |
| 99 | + self, |
| 100 | + num_experts: int, |
| 101 | + top_k: int, |
| 102 | + hidden_size: int, |
| 103 | + intermediate_size: int, |
| 104 | + params_dtype: Optional[torch.dtype] = None, |
| 105 | + reduce_results: bool = False, |
| 106 | + renormalize: bool = True, |
| 107 | + quant_config: Optional[QuantizationConfig] = None, |
| 108 | + tp_size: Optional[int] = None, |
| 109 | + ): |
| 110 | + super().__init__() |
| 111 | + |
| 112 | + if params_dtype is None: |
| 113 | + params_dtype = torch.get_default_dtype() |
| 114 | + |
| 115 | + self.tp_size = (tp_size if tp_size is not None else |
| 116 | + get_tensor_model_parallel_world_size()) |
| 117 | + self.top_k = top_k |
| 118 | + self.num_experts = num_experts |
| 119 | + self.intermediate_size_per_partition = intermediate_size // self.tp_size |
| 120 | + self.reduce_results = reduce_results |
| 121 | + self.renormalize = renormalize |
| 122 | + |
| 123 | + if quant_config is None: |
| 124 | + self.quant_method: Optional[QuantizeMethodBase] = ( |
| 125 | + UnquantizedFusedMoEMethod()) |
| 126 | + else: |
| 127 | + self.quant_method = quant_config.get_quant_method(self) |
| 128 | + assert self.quant_method is not None |
| 129 | + |
| 130 | + self.quant_method.create_weights( |
| 131 | + layer=self, |
| 132 | + num_experts=num_experts, |
| 133 | + hidden_size=hidden_size, |
| 134 | + intermediate_size=self.intermediate_size_per_partition, |
| 135 | + params_dtype=params_dtype, |
| 136 | + weight_loader=self.weight_loader) |
| 137 | + |
| 138 | + def weight_loader(self, param: torch.nn.Parameter, |
| 139 | + loaded_weight: torch.Tensor, weight_name: str, |
| 140 | + shard_id: int, expert_id: int): |
| 141 | + param_data = param.data |
| 142 | + |
| 143 | + # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. |
| 144 | + # Follow up PR to enable fp8 for other MoE models. |
| 145 | + if "input_scale" in weight_name or "w2.weight_scale" in weight_name: |
| 146 | + if param_data[expert_id] != 1 and (param_data[expert_id] - |
| 147 | + loaded_weight).abs() > 1e-5: |
| 148 | + raise ValueError( |
| 149 | + "input_scales of w1 and w3 of a layer " |
| 150 | + f"must be equal. But got {param_data[expert_id]} " |
| 151 | + f"vs. {loaded_weight}") |
| 152 | + param_data[expert_id] = loaded_weight |
| 153 | + # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. |
| 154 | + # Follow up PR to enable fp8 for other MoE models. |
| 155 | + elif "weight_scale" in weight_name: |
| 156 | + # We have to keep the weight scales of w1 and w3 because |
| 157 | + # we need to re-quantize w1/w3 weights after weight loading. |
| 158 | + assert "w1" in weight_name or "w3" in weight_name |
| 159 | + shard_id = 0 if "w1" in weight_name else 1 |
| 160 | + param_data[expert_id][shard_id] = loaded_weight |
| 161 | + else: |
| 162 | + tp_rank = get_tensor_model_parallel_rank() |
| 163 | + shard_size = self.intermediate_size_per_partition |
| 164 | + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) |
| 165 | + |
| 166 | + # w1, gate_proj case: Load into first shard of w13. |
| 167 | + if shard_id == 0: |
| 168 | + param_data[expert_id, |
| 169 | + 0:shard_size, :] = loaded_weight[shard, :] |
| 170 | + # w3, up_proj case: Load into second shard of w13. |
| 171 | + elif shard_id == 2: |
| 172 | + param_data[expert_id, shard_size:2 * |
| 173 | + shard_size, :] = loaded_weight[shard, :] |
| 174 | + # w2, down_proj case: Load into only shard of w2. |
| 175 | + elif shard_id == 1: |
| 176 | + param_data[expert_id, :, :] = loaded_weight[:, shard] |
| 177 | + else: |
| 178 | + raise ValueError( |
| 179 | + f"Shard id must be in [0,1,2] but got {shard_id}") |
| 180 | + |
| 181 | + def forward(self, hidden_states: torch.Tensor, |
| 182 | + router_logits: torch.Tensor): |
| 183 | + assert self.quant_method is not None |
| 184 | + |
| 185 | + # Matrix multiply. |
| 186 | + final_hidden_states = self.quant_method.apply( |
| 187 | + self, |
| 188 | + x=hidden_states, |
| 189 | + router_logits=router_logits, |
| 190 | + top_k=self.top_k, |
| 191 | + renormalize=self.renormalize) |
| 192 | + |
| 193 | + if self.reduce_results and self.tp_size > 1: |
| 194 | + final_hidden_states = tensor_model_parallel_all_reduce( |
| 195 | + final_hidden_states) |
| 196 | + |
| 197 | + return final_hidden_states |
0 commit comments