diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c1abde9af7701..a1642baa2c90c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -369,4 +369,11 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): - param.data.uniform_(low, high) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high)