Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,9 +1948,9 @@ def _can_set_attn_implementation(cls) -> bool:
"""Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
opening the file, but avoids maintaining yet another property flag.
"""
class_module = sys.modules[cls.__module__]
class_module = sys.modules.get(cls.__module__)
# This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
if not hasattr(class_module, "__file__"):
if class_module is None or not hasattr(class_module, "__file__"):
return False
class_file = class_module.__file__
with open(class_file, "r", encoding="utf-8") as f:
Expand All @@ -1967,9 +1967,9 @@ def _can_set_experts_implementation(cls) -> bool:
"""Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
opening the file, but avoids maintaining yet another property flag.
"""
class_module = sys.modules[cls.__module__]
class_module = sys.modules.get(cls.__module__)
# This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
if not hasattr(class_module, "__file__"):
if class_module is None or not hasattr(class_module, "__file__"):
return False
class_file = class_module.__file__
with open(class_file, "r", encoding="utf-8") as f:
Expand Down
27 changes: 27 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,6 +2985,33 @@ def test_identical(self):
self.assertEqual(identical_names, [])


@require_torch
class TestSysModulesMissing(unittest.TestCase):
"""Regression test for #45003: KeyError when module absent from sys.modules."""

def test_can_set_attn_impl_missing_module(self):
from transformers.models.bert.modeling_bert import BertModel

key = BertModel.__module__
saved = sys.modules.pop(key, None)
try:
self.assertFalse(BertModel._can_set_attn_implementation())
finally:
if saved is not None:
sys.modules[key] = saved

def test_can_set_experts_impl_missing_module(self):
from transformers.models.bert.modeling_bert import BertModel

key = BertModel.__module__
saved = sys.modules.pop(key, None)
try:
self.assertFalse(BertModel._can_set_experts_implementation())
finally:
if saved is not None:
sys.modules[key] = saved


@require_torch
class TestSaveAndLoadModelWithExtraState(TestCasePlus):
"""
Expand Down
Loading