File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed
vllm/model_executor/model_loader Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -516,11 +516,17 @@ def default_weight_loader(param: torch.Tensor,
516516 loaded_weight : torch .Tensor ) -> None :
517517 """Default weight loader."""
518518 try :
519- assert param .size () == loaded_weight .size (), (
520- f"Attempted to load weight ({ loaded_weight .size ()} ) "
521- f"into parameter ({ param .size ()} )" )
522-
523- param .data .copy_ (loaded_weight )
519+ if param .numel () == 1 and loaded_weight .numel () == 1 :
520+ # Sometimes scalar values aren't considered tensors with shapes
521+ # so if both param and loaded_weight are a scalar,
522+ # "broadcast" instead of copy
523+ param .data .fill_ (loaded_weight .item ())
524+ else :
525+ assert param .size () == loaded_weight .size (), (
526+ f"Attempted to load weight ({ loaded_weight .size ()} ) "
527+ f"into parameter ({ param .size ()} )" )
528+
529+ param .data .copy_ (loaded_weight )
524530 except Exception :
525531 # NOTE: This exception is added for the purpose of setting breakpoint to
526532 # debug weight loading issues.
You can’t perform that action at this time.
0 commit comments