diff --git a/__init__.py b/__init__.py index b76c6b07..d056efdf 100644 --- a/__init__.py +++ b/__init__.py @@ -5,7 +5,9 @@ datasets, etc. """ +from typing import TYPE_CHECKING as _TYPE_CHECKING import sys as _sys +from ._lazy_loader import LazyLoader as _LazyLoader # We require at least Python 3.7. # See https://github.com/rwth-i6/returnn_common/issues/43. @@ -13,3 +15,11 @@ # - Our code expects that the order of dict is deterministic, or even insertion order specifically. # - Type annotations. assert _sys.version_info[:2] >= (3, 7) + +# Now all the imports. +# Use lazy imports, but only when not type checking. +if _TYPE_CHECKING: + from . import nn # noqa + +else: + nn = _LazyLoader("nn", globals()) diff --git a/_lazy_loader.py b/_lazy_loader.py new file mode 100644 index 00000000..c4bdc33c --- /dev/null +++ b/_lazy_loader.py @@ -0,0 +1,42 @@ + +""" +Lazy module loader. +Code adapted from TensorFlow. +""" + +import importlib +import types +from typing import Dict, Any + + +class LazyLoader(types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies. + """ + + def __init__(self, local_name: str, parent_module_globals: Dict[str, Any]): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + name = f'{parent_module_globals["__package__"]}.{local_name}' + super(LazyLoader, self).__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on lookups + # that fail). + self.__dict__.update(module.__dict__) + + return module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + def __dir__(self): + module = self._load() + return dir(module) +