Make site generation use multiprocessing
bartfeenstra committed Jul 28, 2024
1 parent f735b76 commit 33bb75c
Showing 10 changed files with 403 additions and 200 deletions.
4 changes: 2 additions & 2 deletions betty/cache/_base.py
@@ -48,9 +48,9 @@ def __init__(
         self._scopes = scopes or ()
         self._scoped_caches: dict[str, Self] = {}
         self._locks: MutableMapping[str, _Lock] = defaultdict(
-            AsynchronizedLock.threading
+            AsynchronizedLock.multiprocessing
         )
-        self._locks_lock = AsynchronizedLock.threading()
+        self._locks_lock = AsynchronizedLock.multiprocessing()

     async def _lock(self, cache_item_id: str) -> _Lock:
         async with self._locks_lock:
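With this change, the lock that guards the lock mapping and the per-item locks it hands out are safe to share between worker processes rather than only between threads. A minimal sketch of the lazy per-key locking pattern the `defaultdict` enables (the cache item id is made up; `AsynchronizedLock` supports `async with`, as `_lock()` above shows):

from collections import defaultdict

from betty.concurrent import AsynchronizedLock

# One multiprocessing-safe lock per cache item id, created on first access.
locks = defaultdict(AsynchronizedLock.multiprocessing)


async def update_cache_item(cache_item_id: str) -> None:
    async with locks[cache_item_id]:
        ...  # read, modify, and write the cache item without racing other workers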
8 changes: 8 additions & 0 deletions betty/concurrent.py
@@ -3,6 +3,7 @@
 """

 import asyncio
+import multiprocessing
 import threading
 import time
 from abc import ABC, abstractmethod
@@ -81,6 +82,13 @@ async def acquire(self, *, wait: bool = True) -> bool:
     def release(self) -> None:
         self._lock.release()

+    @classmethod
+    def multiprocessing(cls) -> Self:
+        """
+        Create a new multiprocessing-safe, asynchronous lock.
+        """
+        return cls(multiprocessing.Manager().Lock())
+
     @classmethod
     def threading(cls) -> Self:
         """
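For context on the new classmethod: `multiprocessing.Manager().Lock()` returns a proxy to a lock held in a separate manager process, so, unlike a plain `threading.Lock`, it can be pickled, handed to spawned worker processes, and still coordinate them. A stdlib-only sketch of that property (illustrative, not Betty code):

import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def _increment(lock, counter, n: int) -> None:
    # Each worker process receives a proxy to the same manager-held lock.
    for _ in range(n):
        with lock:
            counter.value += 1  # read-modify-write, safe only under the lock


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    lock = manager.Lock()
    counter = manager.Value("i", 0)
    with ProcessPoolExecutor(
        max_workers=4, mp_context=multiprocessing.get_context("spawn")
    ) as executor:
        workers = [executor.submit(_increment, lock, counter, 1000) for _ in range(4)]
        for worker in workers:
            worker.result()
    print(counter.value)  # 4000, because the lock serialized the updates

`AsynchronizedLock.multiprocessing()` wraps such a proxy behind the same asynchronous `acquire()`/`release()` interface as `AsynchronizedLock.threading()`.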
92 changes: 92 additions & 0 deletions betty/generate/__init__.py
@@ -0,0 +1,92 @@
"""
Provide the Generation API.
"""

from __future__ import annotations

import asyncio
import logging
import os
import shutil
from contextlib import suppress
from pathlib import Path

from aiofiles.os import makedirs

from betty.job import Context
from betty.project import Project
from betty.project import ProjectEvent


class GenerateSiteEvent(ProjectEvent):
    """
    Dispatched to generate a project's site.
    """

    def __init__(self, job_context: GenerationContext):
        super().__init__(job_context.project)
        self._job_context = job_context

    @property
    def job_context(self) -> GenerationContext:
        """
        The site generation job context.
        """
        return self._job_context


class GenerationContext(Context):
    """
    A site generation job context.
    """

    def __init__(self, project: Project):
        super().__init__()
        self._project = project

    @property
    def project(self) -> Project:
        """
        The Betty project this job context is run within.
        """
        return self._project


async def generate(project: Project) -> None:
    """
    Generate a new site.
    """
    from betty.generate.pool import _GenerationProcessPool
    from betty.generate.task import _generate_delegate, _generate_static_public

    logger = logging.getLogger(__name__)
    job_context = GenerationContext(project)
    app = project.app

    logger.info(
        app.localizer._("Generating your site to {output_directory}.").format(
            output_directory=project.configuration.output_directory_path
        )
    )
    with suppress(FileNotFoundError):
        await asyncio.to_thread(
            shutil.rmtree, project.configuration.output_directory_path
        )
    await makedirs(project.configuration.output_directory_path, exist_ok=True)

    # The static public assets may be overridden depending on the number of
    # locales rendered, so ensure they are generated before anything else.
    await _generate_static_public(job_context)

    async with _GenerationProcessPool(project) as process_pool:
        await _generate_delegate(project, process_pool)

    project.configuration.output_directory_path.chmod(0o755)
    for directory_path_str, subdirectory_names, file_names in os.walk(
        project.configuration.output_directory_path
    ):
        directory_path = Path(directory_path_str)
        for subdirectory_name in subdirectory_names:
            (directory_path / subdirectory_name).chmod(0o755)
        for file_name in file_names:
            (directory_path / file_name).chmod(0o644)
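`generate()` is the public entry point of the new package: it builds a `GenerationContext`, writes the static public assets up front, delegates the remaining work to the process pool, and finally normalizes file permissions. A hedged usage sketch, assuming an already-configured `Project` (construction of the app and project is omitted):

import asyncio

from betty.generate import generate


async def build(project) -> None:
    # Wipes the output directory, regenerates the whole site via the worker
    # process pool, and chmods the result so a web server can read it.
    await generate(project)


# asyncio.run(build(project))  # with `project` configured elsewhere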
39 changes: 39 additions & 0 deletions betty/generate/file.py
@@ -0,0 +1,39 @@
"""
File utilities for site generation.
"""

from __future__ import annotations

from typing import AsyncContextManager, cast, TYPE_CHECKING

import aiofiles
from aiofiles.os import makedirs
from aiofiles.threadpool.text import AsyncTextIOWrapper

if TYPE_CHECKING:
    from pathlib import Path


async def create_file(path: Path) -> AsyncContextManager[AsyncTextIOWrapper]:
    """
    Create the file for a resource.
    """
    await makedirs(path.parent, exist_ok=True)
    return cast(
        AsyncContextManager[AsyncTextIOWrapper],
        aiofiles.open(path, "w", encoding="utf-8"),
    )


async def create_html_resource(path: Path) -> AsyncContextManager[AsyncTextIOWrapper]:
    """
    Create the file for an HTML resource.
    """
    return await create_file(path / "index.html")


async def create_json_resource(path: Path) -> AsyncContextManager[AsyncTextIOWrapper]:
    """
    Create the file for a JSON resource.
    """
    return await create_file(path / "index.json")
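Because each helper returns an async context manager rather than an already-open file, callers create the parent directories and write the resource in a single `async with` block. A small usage sketch (the paths and payload are made up):

import json
from pathlib import Path

from betty.generate.file import create_html_resource, create_json_resource


async def write_person(www_directory_path: Path) -> None:
    resource_path = www_directory_path / "person" / "p0001"
    # Writes <resource_path>/index.json, creating parent directories as needed.
    async with await create_json_resource(resource_path) as f:
        await f.write(json.dumps({"id": "p0001"}))
    # Writes <resource_path>/index.html next to it.
    async with await create_html_resource(resource_path) as f:
        await f.write("<!DOCTYPE html><title>p0001</title>")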
227 changes: 227 additions & 0 deletions betty/generate/pool.py
@@ -0,0 +1,227 @@
"""
Provide the site generation multiprocessing pool.
"""

from __future__ import annotations

import logging
import multiprocessing
import os
import queue
from asyncio import run, CancelledError, sleep, create_task
from concurrent import futures
from concurrent.futures import Executor, ProcessPoolExecutor
from contextlib import ExitStack, suppress
from math import floor

from typing import (
    Callable,
    MutableSequence,
    Self,
    Concatenate,
    Any,
    TYPE_CHECKING,
    ParamSpec,
)

import dill

from betty.app import App
from betty.asyncio import gather

from betty.generate import GenerationContext
from betty.project import Project

if TYPE_CHECKING:
    from types import TracebackType
    import threading


_GenerationProcessPoolTaskP = ParamSpec("_GenerationProcessPoolTaskP")

worker_setup: Callable[[], None] | None = None
"""
Set up a worker process, before the worker starts performing tasks.
This may be used to modify the environment, set up mocks, etc.
"""


class _GenerationProcessPool:
    def __init__(self, project: Project):
        self._project = project
        self._queue = multiprocessing.Manager().Queue()
        self._cancel = multiprocessing.Manager().Event()
        self._finish = multiprocessing.Manager().Event()
        self._exit_stack = ExitStack()
        self._executor: Executor | None = None
        self._workers: MutableSequence[futures.Future[None]] = []
        # @todo Ensure this is synchronized across processes.
        self._count_total = 0

    async def __aenter__(self) -> Self:
        concurrency = os.cpu_count() or 2
        # Avoid `fork` so as not to start worker processes with unneeded resources.
        # Settle for `spawn` so all environments use the same start method.
        executor = ProcessPoolExecutor(
            max_workers=concurrency, mp_context=multiprocessing.get_context("spawn")
        )
        self._exit_stack.enter_context(executor)
        # @todo Ensure we pass on the necessary dependencies
        pickled_app_args = dill.dumps(
            (
                self._project.app.configuration,
                self._project.app._cache_directory_path,
            )
        )
        pickled_app_kwargs = dill.dumps(
            {
                # @todo Give it the actual cache.
                "cache_factory": self._project.app._cache_factory,
            }
        )
        pickled_project_args = dill.dumps((self._project.configuration,))
        pickled_project_kwargs = dill.dumps({"ancestry": self._project.ancestry})
        for _ in range(0, concurrency):
            self._workers.append(
                executor.submit(
                    _GenerationProcessPoolWorker(
                        self._queue,
                        self._cancel,
                        self._finish,
                        concurrency,
                        # This is an optional, pickleable callable that can be set during a test.
                        # All we do here is to ensure each worker has access to it.
                        worker_setup,
                        pickled_app_args=pickled_app_args,
                        pickled_app_kwargs=pickled_app_kwargs,
                        pickled_project_args=pickled_project_args,
                        pickled_project_kwargs=pickled_project_kwargs,
                    )
                )
            )
        self._exit_stack.callback(create_task(self._log_jobs_forever()).cancel)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if exc_val is None:
            self._finish.set()
        else:
            self._cancel.set()
        try:
            for worker in futures.as_completed(self._workers):
                worker.result()
        except BaseException:
            self._cancel.set()
            raise
        finally:
            self._exit_stack.close()
        await self._log_jobs()

    async def _log_jobs(self) -> None:
        total_job_count = self._count_total
        completed_job_count = total_job_count - self._queue.qsize()
        logging.getLogger(__name__).info(
            self._project.app.localizer._(
                "Generated {completed_job_count} out of {total_job_count} items ({completed_job_percentage}%)."
            ).format(
                completed_job_count=completed_job_count,
                total_job_count=total_job_count,
                completed_job_percentage=floor(
                    completed_job_count / (total_job_count / 100)
                    if total_job_count > 0
                    else 0
                ),
            )
        )

    async def _log_jobs_forever(self) -> None:
        with suppress(CancelledError):
            while True:
                await self._log_jobs()
                await sleep(5)

    def delegate(
        self,
        task_callable: Callable[
            Concatenate[GenerationContext, _GenerationProcessPoolTaskP], Any
        ],
        *task_args: _GenerationProcessPoolTaskP.args,
        **task_kwargs: _GenerationProcessPoolTaskP.kwargs,
    ) -> None:
        self._queue.put((task_callable, task_args, task_kwargs))
        self._count_total += 1


class _GenerationProcessPoolWorker:
    def __init__(
        self,
        task_queue: queue.Queue[
            tuple[
                Callable[
                    Concatenate[GenerationContext, _GenerationProcessPoolTaskP], Any
                ],
                _GenerationProcessPoolTaskP.args,
                _GenerationProcessPoolTaskP.kwargs,
            ]
        ],
        cancel: threading.Event,
        finish: threading.Event,
        async_concurrency: int,
        setup: Callable[[], None] | None,
        *,
        pickled_app_args: bytes,
        pickled_app_kwargs: bytes,
        pickled_project_args: bytes,
        pickled_project_kwargs: bytes,
    ):
        self._task_queue = task_queue
        self._cancel = cancel
        self._finish = finish
        self._setup = setup
        self._async_concurrency = async_concurrency
        self._pickled_app_args = pickled_app_args
        self._pickled_app_kwargs = pickled_app_kwargs
        self._pickled_project_args = pickled_project_args
        self._pickled_project_kwargs = pickled_project_kwargs

    def __call__(self) -> None:
        if self._setup is not None:
            self._setup()
        run(self._perform_tasks_concurrently())

    async def _perform_tasks_concurrently(self) -> None:
        async with App(
            *dill.loads(self._pickled_app_args), **dill.loads(self._pickled_app_kwargs)
        ) as app, Project(
            app,
            *dill.loads(self._pickled_project_args),
            **dill.loads(self._pickled_project_kwargs),
        ) as project:
            job_context = GenerationContext(project)
            await gather(
                *(
                    self._perform_tasks(job_context)
                    for _ in range(0, self._async_concurrency)
                )
            )

    async def _perform_tasks(self, job_context: GenerationContext) -> None:
        while not self._cancel.is_set():
            try:
                task_callable, task_args, task_kwargs = self._task_queue.get_nowait()
            except queue.Empty:
                if self._finish.is_set():
                    return
            else:
                await task_callable(
                    job_context,
                    *task_args,
                    **task_kwargs,
                )
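The pool's contract is small: `delegate()` enqueues a picklable callable whose first argument will be a `GenerationContext` built inside the worker process, and leaving the `async with` block sets the finish event and waits for the workers to drain the queue. A hedged sketch of how a caller such as `_generate_delegate()` might drive it (the task function and locales below are hypothetical):

from betty.generate import GenerationContext
from betty.generate.pool import _GenerationProcessPool


async def _generate_example_page(job_context: GenerationContext, locale: str) -> None:
    # Hypothetical task: runs inside a worker process, which builds its own
    # App and Project from the pickled configuration before calling it.
    ...


async def _generate_all(project) -> None:
    async with _GenerationProcessPool(project) as pool:
        for locale in ("en-US", "nl-NL"):
            # Tasks and their arguments must be picklable; delegate() only
            # enqueues them, the worker processes pick them up asynchronously.
            pool.delegate(_generate_example_page, locale)
    # Reaching this line means every queued task has been performed.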