Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def launch(self) -> NoReturn:
logger.info("Stopping microservice...")
exception = exc
except Exception as exc: # pragma: no cover
logger.exception("Stopping microservice due to an unhandled exception...")
exception = exc
finally:
self.graceful_shutdown(exception)
Expand All @@ -143,14 +144,14 @@ def graceful_launch(self) -> None:

:return: This method does not return anything.
"""
self.loop.run_until_complete(gather(self.setup(), self.entrypoint.__aenter__()))
self.loop.run_until_complete(self.setup())

def graceful_shutdown(self, err: Exception = None) -> None:
"""Shutdown the execution gracefully.

:return: This method does not return anything.
"""
self.loop.run_until_complete(gather(self.entrypoint.__aexit__(None, err, None), self.destroy()))
self.loop.run_until_complete(self.destroy())

@cached_property
def entrypoint(self) -> Entrypoint:
Expand Down Expand Up @@ -199,9 +200,10 @@ async def _setup(self) -> None:

:return: This method does not return anything.
"""
await self.injector.wire_and_setup_injections(
modules=self._external_modules + self._internal_modules, packages=self._external_packages
)
modules = self._external_modules + self._internal_modules
packages = self._external_packages
self.injector.wire_injections(modules=modules, packages=packages)
await gather(self.injector.setup_injections(), self.entrypoint.__aenter__())

@property
def _internal_modules(self) -> list[ModuleType]:
Expand All @@ -212,7 +214,8 @@ async def _destroy(self) -> None:

:return: This method does not return anything.
"""
await self.injector.unwire_and_destroy_injections()
await gather(self.entrypoint.__aexit__(None, None, None), self.injector.destroy_injections())
self.injector.unwire_injections()

@property
def injections(self) -> dict[str, InjectableMixin]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from unittest.mock import (
AsyncMock,
MagicMock,
call,
patch,
)
Expand Down Expand Up @@ -106,37 +107,67 @@ async def test_loop(self):
self.assertEqual(call(), mock_loop.call_args)

async def test_setup(self):
mock = AsyncMock()
self.launcher.injector.wire_and_setup_injections = mock
await self.launcher.setup()
wire_mock = MagicMock()
setup_mock = AsyncMock()
mock_entrypoint_aenter = AsyncMock()

self.launcher.injector.wire_injections = wire_mock
self.launcher.injector.setup_injections = setup_mock

with patch("minos.common.launchers._create_loop") as mock_loop:
loop = FakeLoop()
mock_loop.return_value = loop
with patch("minos.common.launchers._create_entrypoint") as mock_entrypoint:
entrypoint = FakeEntrypoint()
mock_entrypoint.return_value = entrypoint

entrypoint.__aenter__ = mock_entrypoint_aenter

await self.launcher.setup()

self.assertEqual(1, mock.call_count)
self.assertEqual(1, wire_mock.call_count)
import tests
from minos import (
common,
)

self.assertEqual(0, len(mock.call_args.args))
self.assertEqual(2, len(mock.call_args.kwargs))
observed = mock.call_args.kwargs["modules"]
self.assertEqual(0, len(wire_mock.call_args.args))
self.assertEqual(2, len(wire_mock.call_args.kwargs))
observed = wire_mock.call_args.kwargs["modules"]

self.assertIn(tests, observed)
self.assertIn(common, observed)

self.assertEqual(["tests"], mock.call_args.kwargs["packages"])
self.assertEqual(["tests"], wire_mock.call_args.kwargs["packages"])

await self.launcher.destroy()
self.assertEqual(1, setup_mock.call_count)
self.assertEqual(1, mock_entrypoint_aenter.call_count)

async def test_destroy(self):
self.launcher.injector.wire_and_setup_injections = AsyncMock()
self.launcher._setup = AsyncMock()
await self.launcher.setup()

mock = AsyncMock()
self.launcher.injector.unwire_and_destroy_injections = mock
await self.launcher.destroy()
destroy_mock = AsyncMock()
unwire_mock = MagicMock()
mock_entrypoint_aexit = AsyncMock()

self.launcher.injector.destroy_injections = destroy_mock
self.launcher.injector.unwire_injections = unwire_mock

with patch("minos.common.launchers._create_loop") as mock_loop:
loop = FakeLoop()
mock_loop.return_value = loop
with patch("minos.common.launchers._create_entrypoint") as mock_entrypoint:
entrypoint = FakeEntrypoint()
mock_entrypoint.return_value = entrypoint

entrypoint.__aexit__ = mock_entrypoint_aexit

await self.launcher.destroy()

self.assertEqual(1, mock.call_count)
self.assertEqual(call(), mock.call_args)
self.assertEqual(1, unwire_mock.call_count)
self.assertEqual(1, destroy_mock.call_count)
self.assertEqual(1, mock_entrypoint_aexit.call_count)

def test_launch(self):
mock_setup = AsyncMock()
Expand All @@ -157,8 +188,8 @@ def test_launch(self):

self.launcher.launch()

self.assertEqual(1, mock_entrypoint.call_count)
self.assertEqual(1, mock_loop.call_count)
self.assertEqual(1, mock_setup.call_count)
self.assertEqual(1, mock_destroy.call_count)


class TestEntryPointLauncherLoop(unittest.TestCase):
Expand Down