9 changes: 4 additions & 5 deletions src/transformers/modeling_utils.py

@@ -4910,11 +4910,10 @@ def from_pretrained(
         if device_map is None and not is_deepspeed_zero3_enabled():
             device_in_context = get_torch_context_manager_or_global_device()
             if device_in_context == torch.device("meta"):
-                # TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
-                logger.warning(
-                    "We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
-                    "This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
-                    "the context manager or global device with `from_config`, or `ModelClass(config)`"
+                raise RuntimeError(
+                    "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
+                    "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
+                    "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
                 )
             device_map = device_in_context
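For context, a minimal sketch of what this change means for callers (the model and checkpoint names are illustrative, not from this PR): loading pretrained weights under a meta device context now raises, while building an empty model from a config remains the supported path.

import torch
from transformers import AutoConfig, AutoModel

with torch.device("meta"):
    # Now raises RuntimeError: `from_pretrained` wants to load existing weights,
    # but every tensor created under this context is a weightless meta tensor.
    # model = AutoModel.from_pretrained("bert-base-uncased")

    # Supported path: initialize an empty model on the meta device from a config.
    config = AutoConfig.from_pretrained("bert-base-uncased")
    model = AutoModel.from_config(config)

# All parameters live on the meta device and hold no data.
assert all(p.device == torch.device("meta") for p in model.parameters())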
4 changes: 0 additions & 4 deletions tests/models/perception_lm/test_modeling_perception_lm.py

@@ -313,10 +313,6 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
     def test_can_be_initialized_on_meta(self):
         pass

-    @unittest.skip("ViT PE / TimmWrapperModel cannot be tested with meta device")
-    def test_can_load_with_meta_device_context_manager(self):
-        pass
-
     @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
     def test_generate_from_inputs_embeds_0_greedy(self):
         pass
2 changes: 1 addition & 1 deletion tests/models/timm_backbone/test_modeling_timm_backbone.py

@@ -169,7 +169,7 @@ def test_can_load_with_global_device_set(self):
         pass

     @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
-    def test_can_load_with_meta_device_context_manager(self):
+    def test_cannot_load_with_meta_device_context_manager(self):
         pass

     @unittest.skip(reason="model weights aren't tied in TimmBackbone.")
4 changes: 0 additions & 4 deletions tests/models/xcodec/test_modeling_xcodec.py

@@ -151,10 +151,6 @@ def test_gradient_checkpointing_backward_compatibility(self):
         model = model_class(config)
         self.assertTrue(model.is_gradient_checkpointing)

-    @unittest.skip("XcodecModel cannot be tested with meta device")
-    def test_can_load_with_meta_device_context_manager(self):
-        pass
-
     @unittest.skip(reason="We cannot configure to output a smaller model.")
     def test_model_is_small(self):
         pass
17 changes: 5 additions & 12 deletions tests/test_modeling_common.py

@@ -4488,7 +4488,7 @@ def test_can_load_with_global_device_set(self):
             unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
         )

-    def test_can_load_with_meta_device_context_manager(self):
+    def test_cannot_load_with_meta_device_context_manager(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
             # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@@ -4497,18 +4497,11 @@ def test_can_load_with_meta_device_context_manager(self):

             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)

-                with torch.device("meta"):
-                    new_model = model_class.from_pretrained(tmpdirname)
-                unique_devices = {param.device for param in new_model.parameters()} | {
-                    buffer.device for buffer in new_model.buffers()
-                }
-
-                self.assertEqual(
-                    unique_devices,
-                    {torch.device("meta")},
-                    f"All parameters should be on meta device, but found {unique_devices}.",
-                )
+                with self.assertRaisesRegex(
+                    RuntimeError, "You are using `from_pretrained` with a meta device context manager"
+                ):
+                    _ = model_class.from_pretrained(tmpdirname)

     def test_config_attn_implementation_setter(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
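The updated common test can be reproduced standalone roughly as follows; a sketch assuming pytest, with an illustrative tiny checkpoint name (not taken from this PR):

import pytest
import torch
from transformers import AutoModel

def test_cannot_load_with_meta_device_context_manager():
    # The new guard fires before any weights are read.
    with torch.device("meta"):
        with pytest.raises(RuntimeError, match="meta device context manager"):
            AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")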