[Misc] Remove Mixtral device="cuda" declarations (vllm-project#4543)
Remove the device="cuda" declarations in mixtral as promised in vllm-project#4343
pcmoritz authored May 1, 2024
1 parent f7cdf1f · commit fccc696
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions vllm/model_executor/models/mixtral.py
@@ -96,13 +96,11 @@ def __init__(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
                         dtype=self.params_dtype))
 
         set_weight_attrs(self.ws, {
@@ -114,22 +112,20 @@ def __init__(
 
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
-            torch.ones(
-                self.num_total_experts, device="cuda", dtype=torch.float32),
+            torch.ones(self.num_total_experts, dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
         self.w2s_scale = nn.Parameter(
-            torch.ones(
-                self.num_total_experts, device="cuda", dtype=torch.float32),
+            torch.ones(self.num_total_experts, dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
 
         # Scaling factors for FP8 activations
         need_act_scales = (self.use_fp8
                            and quant_config.activation_scheme == "static")
         self.as_scale = nn.Parameter(
-            torch.zeros(1, device="cuda", dtype=torch.float32),
+            torch.zeros(1, dtype=torch.float32),
             requires_grad=False) if need_act_scales else None
         self.a2s_scale = nn.Parameter(
-            torch.zeros(1, device="cuda", dtype=torch.float32),
+            torch.zeros(1, dtype=torch.float32),
             requires_grad=False) if need_act_scales else None
 
         if need_act_scales:
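For context: with the device="cuda" arguments gone, the expert weights and scales are allocated on whatever device is active when the module is constructed, so placement is decided by the caller (for example via a torch.device context) rather than hardcoded in the model file. Below is a minimal sketch of that pattern; the TinyExperts class and its sizes are hypothetical stand-ins, not the actual vLLM code.

import torch
import torch.nn as nn


class TinyExperts(nn.Module):
    """Hypothetical stand-in for a fused-expert weight container (not vLLM code)."""

    def __init__(self, num_experts: int, hidden: int, inter: int,
                 dtype: torch.dtype = torch.float16):
        super().__init__()
        # No device= argument: the tensors land on whatever device is active
        # when the module is constructed.
        self.ws = nn.Parameter(
            torch.empty(num_experts, 2 * inter, hidden, dtype=dtype))
        self.ws_scale = nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32),
            requires_grad=False)


# Placement is controlled by the caller rather than by the module itself.
cpu_experts = TinyExperts(8, 16, 32)
print(cpu_experts.ws.device)  # cpu

if torch.cuda.is_available():
    # torch.device(...) works as a context manager in PyTorch >= 2.0.
    with torch.device("cuda"):
        gpu_experts = TinyExperts(8, 16, 32)
    print(gpu_experts.ws.device)  # cuda:0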
