Skip to content

Commit

Permalink
[typing] clean up prefect __init__.py
Browse files Browse the repository at this point in the history
- Provide more typing details for version info and the local module spec.
- A module `__getattr__` is never called for globals, so the `_slots`
  structure was redundant.
- Minimise what code is run in the `try: ... except:` block.
- For Python 3.10 and up, include more context in the AttributeError.
  • Loading branch information
mjpieters committed Dec 19, 2024
1 parent 6343036 commit 34faa7b
Showing 1 changed file with 60 additions and 60 deletions.
120 changes: 60 additions & 60 deletions src/prefect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,14 @@

# Setup version and path constants

import sys
from . import _version
import importlib
import pathlib
from typing import TYPE_CHECKING, Any

__version_info__ = _version.get_versions()
__version__ = __version_info__["version"]

# The absolute path to this module
__module_path__ = pathlib.Path(__file__).parent
# The absolute path to the root of the repository, only valid for use during development
__development_base_path__ = __module_path__.parents[1]

# The absolute path to the built UI within the Python module, used by
# `prefect server start` to serve a dynamic build of the UI
__ui_static_subpath__ = __module_path__ / "server" / "ui_build"

# The absolute path to the built UI within the Python module
__ui_static_path__ = __module_path__ / "server" / "ui"

del _version, pathlib
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast

if TYPE_CHECKING:
from importlib.machinery import ModuleSpec
from .main import (
allow_failure,
flow,
Expand All @@ -45,80 +30,95 @@
suspend_flow_run,
)

_slots: dict[str, Any] = {
"__version_info__": __version_info__,
"__version__": __version__,
"__module_path__": __module_path__,
"__development_base_path__": __development_base_path__,
"__ui_static_subpath__": __ui_static_subpath__,
"__ui_static_path__": __ui_static_path__,
}
__spec__: ModuleSpec

# Versioneer provides version information as dictionaries
# with these keys
class VersionInfo(TypedDict("_FullRevisionId", {"full-revisionid": str})):
version: str
dirty: Optional[bool]
error: Optional[str]
date: Optional[str]


__version_info__: "VersionInfo" = cast("VersionInfo", _version.get_versions())
__version__ = __version_info__["version"]

_public_api: dict[str, tuple[str, str]] = {
# The absolute path to this module
__module_path__: pathlib.Path = pathlib.Path(__file__).parent
# The absolute path to the root of the repository, only valid for use during development
__development_base_path__: pathlib.Path = __module_path__.parents[1]

# The absolute path to the built UI within the Python module, used by
# `prefect server start` to serve a dynamic build of the UI
__ui_static_subpath__: pathlib.Path = __module_path__ / "server" / "ui_build"

# The absolute path to the built UI within the Python module
__ui_static_path__: pathlib.Path = __module_path__ / "server" / "ui"

del _version, pathlib

_public_api: dict[str, tuple[Optional[str], str]] = {
"allow_failure": (__spec__.parent, ".main"),
"aserve": (__spec__.parent, ".main"),
"deploy": (__spec__.parent, ".main"),
"flow": (__spec__.parent, ".main"),
"Flow": (__spec__.parent, ".main"),
"get_client": (__spec__.parent, ".main"),
"get_run_logger": (__spec__.parent, ".main"),
"pause_flow_run": (__spec__.parent, ".main"),
"resume_flow_run": (__spec__.parent, ".main"),
"serve": (__spec__.parent, ".main"),
"State": (__spec__.parent, ".main"),
"suspend_flow_run": (__spec__.parent, ".main"),
"tags": (__spec__.parent, ".main"),
"task": (__spec__.parent, ".main"),
"Task": (__spec__.parent, ".main"),
"Transaction": (__spec__.parent, ".main"),
"unmapped": (__spec__.parent, ".main"),
"serve": (__spec__.parent, ".main"),
"aserve": (__spec__.parent, ".main"),
"deploy": (__spec__.parent, ".main"),
"pause_flow_run": (__spec__.parent, ".main"),
"resume_flow_run": (__spec__.parent, ".main"),
"suspend_flow_run": (__spec__.parent, ".main"),
}

# Declare API for type-checkers
__all__ = [
"__development_base_path__",
"__module_path__",
"__ui_static_path__",
"__ui_static_subpath__",
"__version__",
"__version_info__",
"allow_failure",
"aserve",
"deploy",
"flow",
"Flow",
"get_client",
"get_run_logger",
"pause_flow_run",
"resume_flow_run",
"serve",
"State",
"suspend_flow_run",
"tags",
"task",
"Task",
"Transaction",
"unmapped",
"serve",
"aserve",
"deploy",
"pause_flow_run",
"resume_flow_run",
"suspend_flow_run",
"__version_info__",
"__version__",
"__module_path__",
"__development_base_path__",
"__ui_static_subpath__",
"__ui_static_path__",
]


def __getattr__(attr_name: str) -> object:
if attr_name in _slots:
return _slots[attr_name]
try:
dynamic_attr = _public_api.get(attr_name)
if dynamic_attr is None:
return importlib.import_module(f".{attr_name}", package=__name__)

package, module_name = dynamic_attr
def __getattr__(attr_name: str) -> Any:
if (dynamic_attr := _public_api.get(attr_name)) is None:
return importlib.import_module(f".{attr_name}", package=__name__)

from importlib import import_module
package, mname = dynamic_attr

if module_name == "__module__":
return import_module(f".{attr_name}", package=package)
try:
if mname == "__module__":
return importlib.import_module(f".{attr_name}", package=package)
else:
module = import_module(module_name, package=package)
module = importlib.import_module(mname, package=package)
return getattr(module, attr_name)
except ModuleNotFoundError as ex:
module, _, attribute = ex.name.rpartition(".")
raise AttributeError(f"module {module} has no attribute {attribute}") from ex
mname, _, attr = (ex.name or "").rpartition(".")
ctx = {"name": mname, "obj": attr} if sys.version_info >= (3, 10) else {}
raise AttributeError(f"module {mname} has no attribute {attr}", **ctx) from ex

0 comments on commit 34faa7b

Please sign in to comment.