@@ -646,48 +646,49 @@ def forward(
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+            param = local_state[key]
+
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )
 
-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
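The new `if self.num_heads == self.num_key_value_heads:` guard matters because `torch.stack` requires tensors of identical shape: under grouped-query attention (`num_key_value_heads < num_heads`) the k/v projection weights have fewer rows than the q projection, so the fused `qkv_weight` cannot be built and loading falls through to the regular `super()._load_from_state_dict` path. A minimal standalone sketch of that shape constraint (the sizes below are illustrative assumptions, not values from this commit):

```python
import torch

# Hypothetical, small Llama-style dimensions chosen only for illustration.
hidden_size, num_heads, num_key_value_heads, head_dim = 64, 8, 2, 8

q_w = torch.randn(num_heads * head_dim, hidden_size)            # (64, 64)
k_w = torch.randn(num_key_value_heads * head_dim, hidden_size)  # (16, 64) under GQA
v_w = torch.randn(num_key_value_heads * head_dim, hidden_size)  # (16, 64) under GQA

if num_heads == num_key_value_heads:
    # MHA: q/k/v projections share one shape, so they can be fused into a single
    # (3, hidden_size, out_features) tensor, as the checkpoint-loading hack does.
    qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
else:
    # GQA: k_w/v_w differ from q_w along dim 0, so torch.stack would raise a
    # RuntimeError; the projections must stay separate and load via the normal path.
    pass
```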