Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions injection/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +0,0 @@
from importlib.util import find_spec
from typing import Literal

__all__ = ("_is_installed",)


def _is_installed(package: str, needed_for: object, /) -> Literal[True]:
if find_spec(package) is None:
raise RuntimeError(f"To use `{needed_for}`, {package} must be installed.")

return True
6 changes: 2 additions & 4 deletions injection/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
from types import GenericAlias
from typing import Any, TypeAliasType

from fastapi import Depends

from injection import Module, mod
from injection.exceptions import InjectionError
from injection.integrations import _is_installed

__all__ = ("Inject",)

if _is_installed("fastapi", __name__):
from fastapi import Depends


def Inject[T]( # noqa: N802
cls: type[T] | TypeAliasType | GenericAlias,
Expand Down
36 changes: 34 additions & 2 deletions injection/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from collections.abc import Callable, Iterator
from collections.abc import Callable, Iterable, Iterator
from contextlib import contextmanager
from importlib import import_module
from importlib.util import find_spec
from pkgutil import walk_packages
from types import ModuleType as PythonModule
from typing import ContextManager

from injection import __name__ as injection_package_name
from injection import mod

__all__ = ("load_packages", "load_profile")
__all__ = ("load_modules_with_keywords", "load_packages", "load_profile")


def load_profile(*names: str) -> ContextManager[None]:
Expand All @@ -33,6 +35,36 @@ def cleaner() -> Iterator[None]:
return cleaner()


def load_modules_with_keywords(
*packages: PythonModule | str,
keywords: Iterable[str] | None = None,
) -> dict[str, PythonModule]:
"""
Function to import modules from a Python package if one of the keywords is contained in the Python script.
The default keywords are:
- `from injection`
- `import injection`
"""

if keywords is None:
keywords = f"from {injection_package_name}", f"import {injection_package_name}"

b_keywords = tuple(keyword.encode() for keyword in keywords)

def predicate(module_name: str) -> bool:
if (spec := find_spec(module_name)) and (module_path := spec.origin):
with open(module_path, "rb") as script:
for line in script:
line = b" ".join(line.split(b" ")).strip()

if line and any(keyword in line for keyword in b_keywords):
return True

return False

return load_packages(*packages, predicate=predicate)


def load_packages(
*packages: PythonModule | str,
predicate: Callable[[str], bool] = lambda module_name: True,
Expand Down
File renamed without changes.
File renamed without changes.
Empty file.
Empty file added tests/utils/package2/module.py
Empty file.
Empty file.
5 changes: 5 additions & 0 deletions tests/utils/package2/sub_package/injectable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from injection import injectable


@injectable
class SomeInjectable: ...
14 changes: 14 additions & 0 deletions tests/utils/test_load_modules_with_keywords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import sys

from injection.utils import load_modules_with_keywords


class TestLoadModulesWithKeywords:
def test_load_modules_with_keywords_with_success(self):
from tests.utils import package2

loaded_modules = load_modules_with_keywords(package2)
assert len(loaded_modules) == 1
assert "tests.utils.package2.sub_package.injectable" in loaded_modules
assert "tests.utils.package2.sub_package.injectable" in sys.modules
assert "tests.utils.package2.module" not in sys.modules
40 changes: 20 additions & 20 deletions tests/utils/test_load_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,53 @@
from injection.utils import load_packages


class TestLoadPackage:
def test_load_package_with_predicate(self):
from tests.utils import package
class TestLoadPackages:
def test_load_packages_with_predicate(self):
from tests.utils import package1

loaded_modules = load_packages(
package,
package1,
predicate=lambda name: ".excluded_package." not in name,
)

assert "tests.utils.package.excluded_package.module3" not in loaded_modules
assert "tests.utils.package1.excluded_package.module3" not in loaded_modules

modules = (
"tests.utils.package.module1",
"tests.utils.package.sub_package.module2",
"tests.utils.package1.module1",
"tests.utils.package1.sub_package.module2",
)

for module in modules:
assert module in loaded_modules

def test_load_package_with_success(self):
from tests.utils import package
def test_load_packages_with_success(self):
from tests.utils import package1

loaded_modules = load_packages(package)
loaded_modules = load_packages(package1)

modules = (
"tests.utils.package.module1",
"tests.utils.package.sub_package.module2",
"tests.utils.package.excluded_package.module3",
"tests.utils.package1.module1",
"tests.utils.package1.sub_package.module2",
"tests.utils.package1.excluded_package.module3",
)

for module in modules:
assert module in loaded_modules

def test_load_package_with_str(self):
loaded_modules = load_packages("tests.utils.package")
def test_load_packages_with_str(self):
loaded_modules = load_packages("tests.utils.package1")

modules = (
"tests.utils.package.module1",
"tests.utils.package.sub_package.module2",
"tests.utils.package.excluded_package.module3",
"tests.utils.package1.module1",
"tests.utils.package1.sub_package.module2",
"tests.utils.package1.excluded_package.module3",
)

for module in modules:
assert module in loaded_modules

def test_load_package_with_module_raise_type_error(self):
from tests.utils.package import module1
def test_load_packages_with_module_raise_type_error(self):
from tests.utils.package1 import module1

with pytest.raises(TypeError):
load_packages(module1)
Loading