From fccc69643230167d194c629954ef09b085b2be29 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Wed, 1 May 2024 16:30:52 -0700
Subject: [PATCH] [Misc] Remove Mixtral device="cuda" declarations (#4543)

Remove the device="cuda" declarations in mixtral as promised in #4343
---
 vllm/model_executor/models/mixtral.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index c5dd1a63e2f7a..9ff9ba298588a 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -96,13 +96,11 @@ def __init__(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
                         dtype=self.params_dtype))
 
         set_weight_attrs(self.ws, {
@@ -114,22 +112,20 @@ def __init__(
 
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
-            torch.ones(
-                self.num_total_experts, device="cuda", dtype=torch.float32),
+            torch.ones(self.num_total_experts, dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
         self.w2s_scale = nn.Parameter(
-            torch.ones(
-                self.num_total_experts, device="cuda", dtype=torch.float32),
+            torch.ones(self.num_total_experts, dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
 
         # Scaling factors for FP8 activations
         need_act_scales = (self.use_fp8
                            and quant_config.activation_scheme == "static")
         self.as_scale = nn.Parameter(
-            torch.zeros(1, device="cuda", dtype=torch.float32),
+            torch.zeros(1, dtype=torch.float32),
             requires_grad=False) if need_act_scales else None
         self.a2s_scale = nn.Parameter(
-            torch.zeros(1, device="cuda", dtype=torch.float32),
+            torch.zeros(1, dtype=torch.float32),
             requires_grad=False) if need_act_scales else None
 
         if need_act_scales:
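
Not part of the patch: a minimal sketch of why dropping the explicit `device="cuda"` kwarg is safe, assuming the surrounding code establishes the target device when the module is constructed (for example via a `torch.device` context, which PyTorch 2.x factory functions such as `torch.empty` respect). The class and parameter names below are illustrative placeholders, not the actual `MixtralMoE` code.

```python
import torch
import torch.nn as nn


class ToyMoE(nn.Module):
    """Illustrative stand-in for an expert-parallel MoE layer whose
    parameters are created without an explicit device= argument."""

    def __init__(self, num_experts: int, hidden: int, inter: int,
                 dtype: torch.dtype = torch.float16):
        super().__init__()
        # No device="cuda" here: the tensors inherit the ambient default
        # device, so the same code runs on CPU, CUDA, or other backends.
        self.ws = nn.Parameter(
            torch.empty(num_experts, 2 * inter, hidden, dtype=dtype))
        self.w2s = nn.Parameter(
            torch.empty(num_experts, hidden, inter, dtype=dtype))


# The caller decides where the parameters live, e.g. via a device context:
device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device):
    moe = ToyMoE(num_experts=8, hidden=4096, inter=14336)
print(moe.ws.device)  # cuda:0 on a GPU machine, cpu otherwise
```

Keeping device placement out of the layer constructor leaves the decision to the loader or caller, which is what makes hard-coded `device="cuda"` declarations like the ones removed above unnecessary.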