Commit 21313e0

[Bugfix] Fix default weight loading for scalars (#7534)

1 parent f4da5f7 commit 21313e0

File tree

1 file changed: +11 −5 lines changed

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 11 additions & 5 deletions
@@ -516,11 +516,17 @@ def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
     try:
-        assert param.size() == loaded_weight.size(), (
-            f"Attempted to load weight ({loaded_weight.size()}) "
-            f"into parameter ({param.size()})")
-
-        param.data.copy_(loaded_weight)
+        if param.numel() == 1 and loaded_weight.numel() == 1:
+            # Sometimes scalar values aren't considered tensors with shapes
+            # so if both param and loaded_weight are a scalar,
+            # "broadcast" instead of copy
+            param.data.fill_(loaded_weight.item())
+        else:
+            assert param.size() == loaded_weight.size(), (
+                f"Attempted to load weight ({loaded_weight.size()}) "
+                f"into parameter ({param.size()})")
+
+            param.data.copy_(loaded_weight)
     except Exception:
         # NOTE: This exception is added for the purpose of setting breakpoint to
         # debug weight loading issues.
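
For context, a minimal standalone sketch of why the scalar branch is needed. It reproduces the core of the patched loader outside vLLM (the try/except debugging wrapper is omitted); the scale parameter below is a hypothetical stand-in for a quantization scale that a checkpoint stores as a 0-dim tensor. The strict size assertion rejected such a load even though the values are compatible, because torch.Size([]) differs from torch.Size([1]).

    import torch

    def default_weight_loader(param: torch.Tensor,
                              loaded_weight: torch.Tensor) -> None:
        """Default weight loader (core of the patched version above)."""
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # A 0-dim scalar and a shape-(1,) parameter each hold one value
            # but report different sizes, so fill from the Python scalar.
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")
            param.data.copy_(loaded_weight)

    # Hypothetical example: a scalar scale stored as a 0-dim tensor.
    scale = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
    loaded = torch.tensor(0.5)            # shape torch.Size([]), numel() == 1

    assert scale.size() != loaded.size()  # the old assertion failed on this pair
    default_weight_loader(scale, loaded)  # succeeds via the scalar branch
    print(scale.data)                     # tensor([0.5000])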
