
Commit

meh
bartfeenstra committed Jul 28, 2024
1 parent 33bb75c commit 3a62124
Showing 1 changed file with 51 additions and 29 deletions.
80 changes: 51 additions & 29 deletions betty/generate/pool.py
@@ -8,12 +8,10 @@
import multiprocessing
import os
import queue
from asyncio import run, CancelledError, sleep, create_task
from asyncio import run, CancelledError, sleep, to_thread, create_task
from concurrent import futures
from concurrent.futures import Executor, ProcessPoolExecutor
from contextlib import ExitStack, suppress
from math import floor

from contextlib import suppress, ExitStack
from typing import (
Callable,
MutableSequence,
@@ -25,10 +23,10 @@
)

import dill
from math import floor

from betty.app import App
from betty.asyncio import gather

from betty.generate import GenerationContext
from betty.project import Project

@@ -61,13 +59,42 @@ def __init__(self, project: Project):
self._count_total = 0

async def __aenter__(self) -> Self:
try:
await self._start()
except BaseException:
self._cancel.set()
await self._stop()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if exc_val is None:
self._finish.set()
else:
self._cancel.set()
try:
await self._log_jobs()
except BaseException:
self._cancel.set()
raise
finally:
await self._stop()
if exc_val is None:
await self._log_jobs()

async def _start(self) -> None:
concurrency = os.cpu_count() or 2
# Avoid `fork` so as not to start worker processes with unneeded resources.
# Settle for `spawn` so all environments use the same start method.
executor = ProcessPoolExecutor(
max_workers=concurrency, mp_context=multiprocessing.get_context("spawn")
)
self._exit_stack.enter_context(executor)
# We check that the futures are complete in self.__aexit__().
self._exit_stack.callback(lambda: executor.shutdown(wait=False, cancel_futures=True))
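# shutdown(wait=False, cancel_futures=True) returns immediately and cancels any submitted work that has not started running yet.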
# @todo Ensure we pass on the necessary dependencies
pickled_app_args = dill.dumps(
(
@@ -81,8 +108,10 @@
"cache_factory": self._project.app._cache_factory,
}
)
# @todo Can we do this without Dill?
pickled_project_args = dill.dumps((self._project.configuration,))
pickled_project_kwargs = dill.dumps({"ancestry": self._project.ancestry})
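# dill can serialize objects that the standard pickle module rejects (e.g. bound methods and closures), which is presumably why it is used for the app and project arguments here.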
print(F"WORKER COUNT: {concurrency}")
for _ in range(0, concurrency):
self._workers.append(
executor.submit(
@@ -101,28 +130,19 @@
)
)
)
self._exit_stack.callback(create_task(self._log_jobs_forever()).cancel)
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if exc_val is None:
self._finish.set()
else:
self._cancel.set()
try:
for worker in futures.as_completed(self._workers):
worker.result()
except BaseException:
self._cancel.set()
raise
finally:
self._exit_stack.close()
await self._log_jobs()
log_task = create_task(self._log_jobs_forever())
self._exit_stack.callback(log_task.cancel)
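# The exit stack cancels this periodic job-logging task when it is closed in _stop().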

async def _stop(self) -> None:
print("WORKERS AWAITING...")
print(self._finish.is_set())
print(self._cancel.is_set())
# @todo NO TIMEOUT
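# futures.wait() blocks until the worker futures finish, so it is offloaded with to_thread() to keep the event loop responsive.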
await to_thread(futures.wait, self._workers)
print("WORKERS AWAITED")
print("EXIT STACK CLOSING...")
await to_thread(self._exit_stack.close)
print("EXIT STACK CLOSED")

async def _log_jobs(self) -> None:
total_job_count = self._count_total
@@ -195,10 +215,11 @@ def __call__(self) -> None:
if self._setup is not None:
self._setup()
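# Each worker process runs its own asyncio event loop for its share of the tasks.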
run(self._perform_tasks_concurrently())
print('WORKER STOP')

async def _perform_tasks_concurrently(self) -> None:
async with App(
*dill.loads(self._pickled_app_args), **dill.loads(self._pickled_app_kwargs)
*dill.loads(self._pickled_app_args), **dill.loads(self._pickled_app_kwargs),
) as app, Project(
app,
*dill.loads(self._pickled_project_args),
@@ -214,6 +235,7 @@ async def _perform_tasks_concurrently(self) -> None:

async def _perform_tasks(self, job_context: GenerationContext) -> None:
while not self._cancel.is_set():
print('WORKER ITERATION')
try:
task_callable, task_args, task_kwargs = self._task_queue.get_nowait()
except queue.Empty:
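For context, a minimal usage sketch of the pool (not part of this commit): the pool is an async context manager whose worker processes consume tasks from a shared queue. The class name _GenerationProcessPool and the delegate() method below are assumed names for illustration, not names confirmed by this diff.

async def generate_site(project: Project) -> None:
    # Hypothetical: _GenerationProcessPool, delegate(), and render_static_pages are illustrative names only.
    async with _GenerationProcessPool(project) as pool:
        # Anything delegated here ends up on the shared task queue and is picked up
        # by _perform_tasks() inside one of the spawned worker processes.
        await pool.delegate(render_static_pages)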
