208 changes: 201 additions & 7 deletions src/transformers/modeling_utils.py
@@ -5113,12 +5113,203 @@ def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool)
be initialized correctly (i.e. weight initialization distribution).
Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
"""
for key in self.state_dict():
missing_keys_set = set(missing_keys)

model_state_dict_keys = set(self.state_dict().keys())

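# Fast path: every key of the model is missing, so initialize the whole model at once and return.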
if missing_keys_set and missing_keys_set >= model_state_dict_keys:
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = list(self.state_dict(keep_vars=True).values())
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
self.initialize_weights()
else:
self.initialize_weights()
return

for key in model_state_dict_keys:
# If it's part of the keys that will be loaded, mark it as already initialized
if key not in missing_keys:
if key not in missing_keys_set:
param_or_buffer = self.get_parameter_or_buffer(key)
param_or_buffer._is_hf_initialized = True

handled_missing_keys: set[str] = set()

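# Group the missing keys by the module that owns them so each affected module can be handled individually.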
if missing_keys_set and not is_quantized:
missing_params_by_module: defaultdict[str, set[str]] = defaultdict(set)
missing_buffers_by_module: defaultdict[str, set[str]] = defaultdict(set)

for key in missing_keys_set:
if "." not in key:
continue
module_path, name = key.rsplit(".", 1)
if module_path == "":
continue
try:
module = self.get_submodule(module_path)
except AttributeError:
continue

parameters = getattr(module, "_parameters", {})
if name in parameters and parameters[name] is not None:
missing_params_by_module[module_path].add(name)
continue

buffers = getattr(module, "_buffers", {})
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
if name in buffers and buffers[name] is not None and name not in non_persistent:
missing_buffers_by_module[module_path].add(name)

# Sort by depth (deepest first) so child modules are handled before their parents.
module_paths = sorted(
set(missing_params_by_module.keys()) | set(missing_buffers_by_module.keys()),
key=lambda name: name.count("."),
reverse=True,
Comment on lines +5165 to +5168
Copilot AI (Nov 7, 2025): The sorting logic (by dot count in reverse order) is important for proper parent-child module initialization but lacks a comment explaining why modules must be processed from deepest to shallowest. Add a comment clarifying that this ensures child modules are initialized before their parents.
)

modules_info: list[tuple[str, nn.Module, set[str], set[str]]] = []
for module_path in module_paths:
try:
module = self.get_submodule(module_path)
except AttributeError:
continue
modules_info.append(
(
module_path,
module,
missing_params_by_module.get(module_path, set()),
missing_buffers_by_module.get(module_path, set()),
)
)

if modules_info:

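# Re-initialize only the affected modules, restoring any tensors that were actually loaded from the checkpoint.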
def _initialize_modules():
with torch.no_grad():
for module_path, module, module_missing_params, module_missing_buffers in modules_info:
immediate_params = dict(module.named_parameters(recurse=False))
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
persistent_buffers = {
name: buffer
for name, buffer in module.named_buffers(recurse=False)
if name not in non_persistent
}

already_initialized_params = {
name
for name in module_missing_params
if name in immediate_params
and getattr(immediate_params[name], "_is_hf_initialized", False)
}
already_initialized_buffers = {
name
for name in module_missing_buffers
if name in persistent_buffers
and getattr(persistent_buffers[name], "_is_hf_initialized", False)
}
if already_initialized_params or already_initialized_buffers:
handled_missing_keys.update(
{f"{module_path}.{name}" for name in already_initialized_params}
)
handled_missing_keys.update(
{f"{module_path}.{name}" for name in already_initialized_buffers}
)

missing_params = {
name
for name in module_missing_params
if name in immediate_params
and not getattr(immediate_params[name], "_is_hf_initialized", False)
}
missing_buffers = {
name
for name in module_missing_buffers
if name in persistent_buffers
and not getattr(persistent_buffers[name], "_is_hf_initialized", False)
}

if not missing_params and not missing_buffers:
continue

all_param_names = set(immediate_params.keys())
all_buffer_names = set(persistent_buffers.keys())
# If every immediate parameter and buffer is missing, re-initialize the whole module.
fully_missing = (not all_param_names or missing_params == all_param_names) and (
not all_buffer_names or missing_buffers == all_buffer_names
)

if fully_missing:
self._initialize_weights(module)
else:
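# Only some of the module's tensors are missing: snapshot the loaded ones, re-initialize the module, then copy the snapshots back.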
if is_deepspeed_zero3_enabled():
import deepspeed

preserved_parameters = {}
for name, param in immediate_params.items():
if name in missing_params:
continue
with deepspeed.zero.GatheredParameters([param], modifier_rank=None):
preserved_parameters[name] = param.detach().clone()

preserved_buffers = {}
for name, buffer in persistent_buffers.items():
if name in missing_buffers or buffer is None:
continue
with deepspeed.zero.GatheredParameters([buffer], modifier_rank=None):
preserved_buffers[name] = buffer.detach().clone()
else:
preserved_parameters = {
name: param.detach().clone()
for name, param in immediate_params.items()
if name not in missing_params
}
preserved_buffers = {
name: buffer.detach().clone()
for name, buffer in persistent_buffers.items()
if name not in missing_buffers and buffer is not None
}

self._initialize_weights(module)

for name, tensor in preserved_parameters.items():
module._parameters[name].data.copy_(tensor)
module._parameters[name]._is_hf_initialized = True

for name, tensor in preserved_buffers.items():
buffer = module._buffers[name]
buffer.data.copy_(tensor)
buffer._is_hf_initialized = True

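# Mark the freshly initialized tensors so the generic initialization pass below leaves them alone.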
for name in missing_params:
param = module._parameters.get(name)
if param is not None:
param._is_hf_initialized = True
handled_missing_keys.add(f"{module_path}.{name}")

for name in missing_buffers:
buffer = module._buffers.get(name)
if buffer is not None:
buffer._is_hf_initialized = True
handled_missing_keys.add(f"{module_path}.{name}")

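# Under ZeRO-3 the parameters are sharded across ranks, so gather them before modifying them in place.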
if is_deepspeed_zero3_enabled():
import deepspeed

params_to_gather = []
for _, module, _, _ in modules_info:
params_to_gather.extend(list(module.parameters(recurse=False)))

if params_to_gather:
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
_initialize_modules()
else:
_initialize_modules()
else:
_initialize_modules()

missing_keys_set -= handled_missing_keys

def set_is_initialized_for_modules(module):
# A module is already initialized if and only if all its children are also already initialized, and all
# its immediate `nn.Parameter` and persistent buffers are also already initialized
@@ -5141,14 +5332,17 @@ def set_is_initialized_for_modules(module):
# each param)
self.apply(set_is_initialized_for_modules)

# This will only initialize submodules that are not marked as initialized by the line above.
not_initialized_parameters = list(
{v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
)

if not not_initialized_parameters:
return

# This will only initialize submodules that are not marked as initialized by the logic above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
not_initialized_parameters = list(
{v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
self.initialize_weights()
else:
104 changes: 104 additions & 0 deletions tests/test_modeling_common.py
@@ -102,6 +102,7 @@
CONFIG_NAME,
GENERATION_CONFIG_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
@@ -3898,6 +3899,109 @@ def test_bc_torch_dtype(self):
self.assertTrue((v1 == v2).all())


class TestMissingKeyInitialization:
@require_torch
def test_missing_linear_bias_does_not_override_weight(self):
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig(
hidden_size=8, num_attention_heads=2, num_hidden_layers=1, num_labels=2, intermediate_size=16
)
model = BertForSequenceClassification(config)
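# Fill the classifier weight with a sentinel pattern so any re-initialization would be detected.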
with torch.no_grad():
sentinel = torch.arange(model.classifier.weight.numel(), dtype=torch.float32).view_as(
model.classifier.weight
)
model.classifier.weight.copy_(sentinel)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)
weights_file = os.path.join(tmpdirname, WEIGHTS_NAME)

state_dict = torch.load(weights_file)
expected_weight = state_dict["classifier.weight"].clone()
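# Drop the bias from the serialized checkpoint to simulate a missing key.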
state_dict.pop("classifier.bias")
torch.save(state_dict, weights_file)

reloaded = BertForSequenceClassification.from_pretrained(tmpdirname)

torch.testing.assert_close(reloaded.classifier.weight, expected_weight)
torch.testing.assert_close(reloaded.classifier.bias, torch.zeros_like(reloaded.classifier.bias))

@require_torch
def test_missing_tied_decoder_weight_preserves_embeddings(self):
from transformers import BertConfig, BertForMaskedLM, BertModel

config = BertConfig(
hidden_size=8, num_attention_heads=2, num_hidden_layers=1, num_labels=2, intermediate_size=16
)
base_model = BertModel(config)
with torch.no_grad():
sentinel = torch.arange(base_model.embeddings.word_embeddings.weight.numel(), dtype=torch.float32).view_as(
base_model.embeddings.word_embeddings.weight
)
base_model.embeddings.word_embeddings.weight.copy_(sentinel)
expected_embeddings = base_model.embeddings.word_embeddings.weight.detach().clone()

with tempfile.TemporaryDirectory() as tmpdirname:
base_model.save_pretrained(tmpdirname)
mlm_model = BertForMaskedLM.from_pretrained(tmpdirname)

decoder_weight = mlm_model.cls.predictions.decoder.weight
word_embeddings = mlm_model.bert.embeddings.word_embeddings.weight

torch.testing.assert_close(decoder_weight, expected_embeddings)
torch.testing.assert_close(word_embeddings, expected_embeddings)
assert decoder_weight.data_ptr() == word_embeddings.data_ptr()

@require_torch
def test_missing_linear_bias_preserves_weight_custom_model(self):
import torch
from torch import nn

from transformers import PreTrainedConfig, PreTrainedModel

class _CustomConfig(PreTrainedConfig):
model_type = "custom-missing-bias"

def __init__(self, hidden_size=4, **kwargs):
super().__init__(**kwargs)
self.hidden_size = hidden_size

class _CustomModel(PreTrainedModel):
config_class = _CustomConfig

def __init__(self, config):
super().__init__(config)
self.linear = nn.Linear(config.hidden_size, config.hidden_size)
self.post_init()

def forward(self, x):
return self.linear(x)

config = _CustomConfig(hidden_size=4)
model = _CustomModel(config)
with torch.no_grad():
sentinel = torch.arange(model.linear.weight.numel(), dtype=torch.float32).view_as(model.linear.weight)
model.linear.weight.copy_(sentinel)
model.linear.bias.zero_()

reloaded = _CustomModel(config)
with torch.no_grad():
reloaded.linear.weight.copy_(sentinel)
reloaded.linear.bias.zero_()

reloaded.linear.weight._is_hf_initialized = True
if hasattr(reloaded.linear.bias, "_is_hf_initialized"):
delattr(reloaded.linear.bias, "_is_hf_initialized")
reloaded.linear._is_hf_initialized = False

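# Report only the bias as missing; the already-initialized weight must be left untouched.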
reloaded._initialize_missing_keys(["linear.bias"], is_quantized=False)

torch.testing.assert_close(reloaded.linear.weight, sentinel)
torch.testing.assert_close(reloaded.linear.bias, torch.zeros_like(reloaded.linear.bias))


global_rng = random.Random()

