enable mixtral quantization using INC (vllm-project#372)

dudilester authored Oct 15, 2024
1 parent 57bc31d commit 55dd07e
Showing 3 changed files with 37 additions and 65 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@36c7f9c
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7531cc6
3 changes: 3 additions & 0 deletions vllm/executor/ray_hpu_executor.py
@@ -78,6 +78,9 @@ def shutdown(self) -> None:
ray.kill(worker)
self.forward_dag = None

+def finish_measurements(self):
+self._run_workers("finish_measurements")
+
def _get_worker_module_and_class(
self
) -> Tuple[str, str, Optional[Callable[[],
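The new executor hook simply fans finish_measurements out to every Ray worker via _run_workers; each worker is expected to flush the calibration data collected by INC (Intel Neural Compressor) measurement mode. A minimal sketch of how a calibration run might drive it is shown below; only the executor's finish_measurements() comes from this commit, while the model name, the quantization="inc" argument, and the prompt list are illustrative assumptions:

    # Hypothetical INC measurement flow; only finish_measurements() on the
    # executor comes from this commit, the rest is assumed for illustration.
    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", quantization="inc")

    # Run a few prompts so INC can record activation statistics per layer.
    for prompt in ["Hello, my name is", "The capital of France is"]:
        llm.generate(prompt, SamplingParams(max_tokens=8))

    # Broadcast to every Ray worker; each one dumps its measurement files so a
    # later run can load them and quantize the Mixtral experts.
    llm.llm_engine.model_executor.finish_measurements()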
97 changes: 33 additions & 64 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -14,6 +14,8 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

+is_hpu = current_platform.is_hpu()
+
logger = init_logger(__name__)


@@ -262,21 +264,23 @@ def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
-tp_rank: int):
+tp_rank: int, expert_id: int):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
-tp_rank=tp_rank)
+tp_rank=tp_rank,
+expert_id=expert_id)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
-tp_rank=tp_rank)
+tp_rank=tp_rank,
+expert_id=expert_id)

def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
@@ -292,9 +296,15 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
expert_data=expert_data,
tp_rank=tp_rank)

-def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
-shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
+def _load_w13(self,
+expert_data: torch.Tensor,
+shard_dim: int,
+shard_id: str,
+loaded_weight: torch.tensor,
+tp_rank: int,
+expert_id: Optional[int] = None):

+orig_exp_data = expert_data.view(expert_data.size())
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
@@ -310,8 +320,17 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)

-def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
+if is_hpu:
+self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
+orig_exp_data)
+
+def _load_w2(self,
+expert_data: torch.Tensor,
+shard_dim: int,
+shard_id: str,
+loaded_weight: torch.tensor,
+tp_rank: int,
+expert_id: Optional[int] = None):

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
@@ -321,6 +340,9 @@ def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
+if is_hpu:
+self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
+expert_data)

def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
@@ -423,7 +445,8 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
-tp_rank=tp_rank)
+tp_rank=tp_rank,
+expert_id=expert_id)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
@@ -449,7 +472,8 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
-tp_rank=tp_rank)
+tp_rank=tp_rank,
+expert_id=expert_id)
return

@staticmethod
@@ -528,58 +552,3 @@ def make_expert_params_mapping(
("w3", ckpt_up_proj_name),
]
]

-def _load_fp8_scale(self, param: torch.nn.Parameter,
-loaded_weight: torch.Tensor, weight_name: str,
-shard_id: str, expert_id: int) -> None:
-param_data = param.data
-
-# Input scales can be loaded directly and should be equal.
-if "input_scale" in weight_name:
-if param_data[expert_id] != 1 and (param_data[expert_id] -
-loaded_weight).abs() > 1e-5:
-raise ValueError(
-"input_scales of w1 and w3 of a layer "
-f"must be equal. But got {param_data[expert_id]} "
-f"vs. {loaded_weight}")
-param_data[expert_id] = loaded_weight
-# Weight scales
-elif "weight_scale" in weight_name:
-# If we are in merged column case (gate_up_proj)
-if shard_id in ("w1", "w3"):
-# We have to keep the weight scales of w1 and w3 because
-# we need to re-quantize w1/w3 weights after weight loading.
-idx = 0 if shard_id == "w1" else 1
-param_data[expert_id][idx] = loaded_weight
-# If we are in the row parallel case (down_proj)
-else:
-param_data[expert_id] = loaded_weight
-# Weights
-else:
-tp_rank = get_tensor_model_parallel_rank()
-shard_size = self.intermediate_size_per_partition
-shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-
-# w1, gate_proj case: Load into first shard of w13.
-if shard_id == 0:
-param_data[expert_id,
-0:shard_size, :] = loaded_weight[shard, :]
-if current_platform.is_hpu():
-self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
-param_data[expert_id])
-# w3, up_proj case: Load into second shard of w13.
-elif shard_id == 2:
-param_data[expert_id, shard_size:2 *
-shard_size, :] = loaded_weight[shard, :]
-if current_platform.is_hpu():
-self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
-param_data[expert_id])
-# w2, down_proj case: Load into only shard of w2.
-elif shard_id == 1:
-param_data[expert_id, :, :] = loaded_weight[:, shard]
-if current_platform.is_hpu():
-self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
-param_data[expert_id])
-else:
-raise ValueError(
-f"Shard id must be in [0,1,2] but got {shard_id}")
