From eacdce9ed99cf74f0b8d8ccb8f12233d1b2c4ceb Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 24 Nov 2016 14:49:29 -0800 Subject: [PATCH] Track tasks only during shutdown and tests (#4428) * Track tasks only when needed * Tweak async_block_till_done --- homeassistant/core.py | 67 +++++++++++++++++++++---------------------- tests/common.py | 6 ++-- tests/test_core.py | 14 +++------ tests/test_remote.py | 1 + 4 files changed, 41 insertions(+), 47 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index 645bdc68b0a66..de79843495672 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -113,7 +113,6 @@ def __init__(self, loop=None): self.loop.set_default_executor(self.executor) self.loop.set_exception_handler(self._async_exception_handler) self._pending_tasks = [] - self._pending_sheduler = None self.bus = EventBus(self) self.services = ServiceRegistry(self) self.states = StateMachine(self.bus, self.loop) @@ -185,34 +184,41 @@ def async_start(self): # pylint: disable=protected-access self.loop._thread_ident = threading.get_ident() - self._async_tasks_cleanup() _async_create_timer(self) self.bus.async_fire(EVENT_HOMEASSISTANT_START) self.state = CoreState.running - @callback - def _async_tasks_cleanup(self): - """Cleanup all pending tasks in a time interval. + def add_job(self, target: Callable[..., None], *args: Any) -> None: + """Add job to the executor pool. - This method must be run in the event loop. + target: target to call. + args: parameters for method to call. """ - self._pending_tasks = [task for task in self._pending_tasks - if not task.done()] + self.loop.call_soon_threadsafe(self.async_add_job, target, *args) - # sheduled next cleanup - self._pending_sheduler = self.loop.call_later( - TIME_INTERVAL_TASKS_CLEANUP, self._async_tasks_cleanup) + @callback + def _async_add_job(self, target: Callable[..., None], *args: Any) -> None: + """Add a job from within the eventloop. - def add_job(self, target: Callable[..., None], *args: Any) -> None: - """Add job to the executor pool. + This method must be run in the event loop. target: target to call. args: parameters for method to call. """ - self.loop.call_soon_threadsafe(self.async_add_job, target, *args) + if asyncio.iscoroutine(target): + self.loop.create_task(target) + elif is_callback(target): + self.loop.call_soon(target, *args) + elif asyncio.iscoroutinefunction(target): + self.loop.create_task(target(*args)) + else: + self.loop.run_in_executor(None, target, *args) + + async_add_job = _async_add_job @callback - def async_add_job(self, target: Callable[..., None], *args: Any) -> None: + def _async_add_job_tracking(self, target: Callable[..., None], + *args: Any) -> None: """Add a job from within the eventloop. This method must be run in the event loop. @@ -235,6 +241,11 @@ def async_add_job(self, target: Callable[..., None], *args: Any) -> None: if task is not None: self._pending_tasks.append(task) + @callback + def async_track_tasks(self): + """Track tasks so you can wait for all tasks to be done.""" + self.async_add_job = self._async_add_job_tracking + @callback def async_run_job(self, target: Callable[..., None], *args: Any) -> None: """Run a job from within the event loop. @@ -249,16 +260,6 @@ def async_run_job(self, target: Callable[..., None], *args: Any) -> None: else: self.async_add_job(target, *args) - def _loop_empty(self) -> bool: - """Python 3.4.2 empty loop compatibility function.""" - # pylint: disable=protected-access - if sys.version_info < (3, 4, 3): - return len(self.loop._scheduled) == 0 and \ - len(self.loop._ready) == 0 - else: - return self.loop._current_handle is None and \ - len(self.loop._ready) == 0 - def block_till_done(self) -> None: """Block till all pending work is done.""" run_coroutine_threadsafe( @@ -267,18 +268,17 @@ def block_till_done(self) -> None: @asyncio.coroutine def async_block_till_done(self): """Block till all pending work is done.""" - while True: - # Wait for the pending tasks are down + # To flush out any call_soon_threadsafe + yield from asyncio.sleep(0, loop=self.loop) + + while self._pending_tasks: pending = [task for task in self._pending_tasks if not task.done()] self._pending_tasks.clear() if len(pending) > 0: yield from asyncio.wait(pending, loop=self.loop) - - # Verify the loop is empty - ret = yield from self.loop.run_in_executor(None, self._loop_empty) - if ret and not self._pending_tasks: - break + else: + yield from asyncio.sleep(0, loop=self.loop) def stop(self) -> None: """Stop Home Assistant and shuts down all threads.""" @@ -291,9 +291,8 @@ def async_stop(self) -> None: This method is a coroutine. """ self.state = CoreState.stopping + self.async_track_tasks() self.bus.async_fire(EVENT_HOMEASSISTANT_STOP) - if self._pending_sheduler is not None: - self._pending_sheduler.cancel() yield from self.async_block_till_done() self.executor.shutdown() if self._websession is not None: diff --git a/tests/common.py b/tests/common.py index 525d7f85bd384..25a10783c28ba 100644 --- a/tests/common.py +++ b/tests/common.py @@ -82,6 +82,7 @@ def async_test_home_assistant(loop): loop._thread_ident = threading.get_ident() hass = ha.HomeAssistant(loop) + hass.async_track_tasks() hass.config.location_name = 'test home' hass.config.config_dir = get_test_config_dir() @@ -103,9 +104,8 @@ def async_test_home_assistant(loop): @asyncio.coroutine def mock_async_start(): """Start the mocking.""" - with patch.object(loop, 'add_signal_handler'),\ - patch('homeassistant.core._async_create_timer'),\ - patch.object(hass, '_async_tasks_cleanup', return_value=None): + with patch.object(loop, 'add_signal_handler'), \ + patch('homeassistant.core._async_create_timer'): yield from orig_start() hass.async_start = mock_async_start diff --git a/tests/test_core.py b/tests/test_core.py index 212c6d41f7093..9221ad68352cb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,8 +9,7 @@ import homeassistant.core as ha from homeassistant.exceptions import InvalidEntityFormatError -from homeassistant.util.async import ( - run_callback_threadsafe, run_coroutine_threadsafe) +from homeassistant.util.async import run_coroutine_threadsafe import homeassistant.util.dt as dt_util from homeassistant.util.unit_system import (METRIC_SYSTEM) from homeassistant.const import ( @@ -129,7 +128,7 @@ def test_coro(): """Test Coro.""" call_count.append('call') - for i in range(50): + for i in range(3): self.hass.add_job(test_coro()) run_coroutine_threadsafe( @@ -137,13 +136,8 @@ def test_coro(): loop=self.hass.loop ).result() - with patch.object(self.hass.loop, 'call_later') as mock_later: - run_callback_threadsafe( - self.hass.loop, self.hass._async_tasks_cleanup).result() - assert mock_later.called - - assert len(self.hass._pending_tasks) == 0 - assert len(call_count) == 50 + assert len(self.hass._pending_tasks) == 3 + assert len(call_count) == 3 def test_async_add_job_pending_tasks_coro(self): """Add a coro to pending tasks.""" diff --git a/tests/test_remote.py b/tests/test_remote.py index 55d8ca18b5fc8..fa2a53a96cb73 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -61,6 +61,7 @@ def setUpModule(): target=loop.run_forever).start() slave = remote.HomeAssistant(master_api, loop=loop) + slave.async_track_tasks() slave.config.config_dir = get_test_config_dir() slave.config.skip_pip = True bootstrap.setup_component(