[Bugfix] Fix dummy weight for fp8 (vllm-project#4916)
Allow the dummy load format for fp8 weights;
torch's Tensor.uniform_ doesn't support FP8 at the moment.

Co-authored-by: Mor Zusman <morz@ai21.com>
2 people authored and joerunde committed Jun 3, 2024
1 parent 292c4ec commit aa76008
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
@@ -369,4 +369,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
+            if torch.finfo(param.data.dtype).bits < 16:
+                # uniform_ doesn't support < 16-bit datatypes (FP8)
+                dtype = param.data.dtype
+                tmp_param = param.data.to(torch.float16)
+                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                param.data.copy_(tmp_param)
+            else:
+                param.uniform_(low, high)

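For context, here is a minimal, standalone sketch of the same workaround, assuming a PyTorch build that exposes an FP8 dtype (torch.float8_e4m3fn); the function name fill_dummy and the value range are illustrative and not part of vLLM.

import torch

def fill_dummy(param: torch.Tensor, low: float = -1e-3, high: float = 1e-3) -> None:
    """Fill a floating-point tensor with uniform noise in place.

    Dtypes narrower than 16 bits (e.g. FP8) are routed through a float16
    temporary, since Tensor.uniform_ is not implemented for them.
    """
    if torch.finfo(param.dtype).bits < 16:
        tmp = param.to(torch.float16)      # widen to a dtype uniform_ supports
        tmp.uniform_(low, high)            # fill with uniform noise in float16
        param.copy_(tmp.to(param.dtype))   # cast back and copy in place
    else:
        param.uniform_(low, high)

# Usage: an FP8 parameter gets filled via the float16 round trip.
w = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
fill_dummy(w)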