Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-117378: Fix multiprocessing forkserver preload sys.path inheritance. #126538

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Lib/multiprocessing/forkserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def ensure_running(self):
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
'''Run forkserver.'''
if preload:
if sys_path is not None:
sys.path[:] = sys_path
if '__main__' in preload and main_path is not None:
process.current_process()._inheriting = True
try:
Expand Down
77 changes: 77 additions & 0 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
import os
import gc
import importlib
import errno
import functools
import signal
Expand All @@ -20,8 +21,10 @@
import socket
import random
import logging
import shutil
import subprocess
import struct
import tempfile
import operator
import pickle
import weakref
Expand Down Expand Up @@ -6275,6 +6278,80 @@ def test_atexit(self):
self.assertEqual(f.read(), 'deadbeef')


class _TestSpawnedSysPath(BaseTestCase):
"""Test that sys.path is setup in forkserver and spawn processes."""

ALLOWED_TYPES = ('processes',)

def setUp(self):
self._orig_sys_path = list(sys.path)
self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-")
gpshead marked this conversation as resolved.
Show resolved Hide resolved
self._mod_name = "unique_test_mod"
module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py")
with open(module_path, "w", encoding="utf-8") as mod:
mod.write("# A simple test module\n")
sys.path[:] = [p for p in sys.path if p] # remove any existing ""s
sys.path.insert(0, self._temp_dir)
sys.path.insert(0, "") # Replaced with an abspath in child.
gpshead marked this conversation as resolved.
Show resolved Hide resolved
try:
self._ctx_forkserver = multiprocessing.get_context("forkserver")
except ValueError:
self._ctx_forkserver = None
self._ctx_spawn = multiprocessing.get_context("spawn")

def tearDown(self):
sys.path[:] = self._orig_sys_path
shutil.rmtree(self._temp_dir, ignore_errors=True)

@staticmethod
def enq_imported_module_names(queue):
queue.put(tuple(sys.modules))

def test_forkserver_preload_imports_sys_path(self):
if not (ctx := self._ctx_forkserver):
gpshead marked this conversation as resolved.
Show resolved Hide resolved
self.skipTest("requires forkserver start method.")
self.assertNotIn(self._mod_name, sys.modules)
multiprocessing.forkserver._forkserver._stop() # Must be fresh.
ctx.set_forkserver_preload(
["test.test_multiprocessing_forkserver", self._mod_name])
q = ctx.Queue()
proc = ctx.Process(target=self.enq_imported_module_names, args=(q,))
proc.start()
proc.join()
child_imported_modules = q.get()
q.close()
self.assertIn(self._mod_name, child_imported_modules)

@staticmethod
def enq_sys_path_and_import(queue, mod_name):
queue.put(sys.path)
try:
importlib.import_module(mod_name)
except ImportError as exc:
queue.put(exc)
else:
queue.put(None)

def test_child_sys_path(self):
for ctx in (self._ctx_spawn, self._ctx_forkserver):
if not ctx:
continue
gpshead marked this conversation as resolved.
Show resolved Hide resolved
with self.subTest(f"{ctx.get_start_method()} start method"):
q = ctx.Queue()
proc = ctx.Process(target=self.enq_sys_path_and_import,
args=(q, self._mod_name))
proc.start()
proc.join()
child_sys_path = q.get()
import_error = q.get()
q.close()
self.assertNotIn("", child_sys_path) # replaced by an abspath
self.assertIn(self._temp_dir, child_sys_path) # our addition
# ignore the first element, it is the absolute "" replacement
self.assertEqual(child_sys_path[1:], sys.path[1:])
self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}")


class MiscTestCase(unittest.TestCase):
def test__all__(self):
# Just make sure names in not_exported are excluded
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Fixed the :mod:`multiprocessing` ``"forkserver"`` start method forkserver
process was to correctly inherit the parent's :data:`sys.path` during the
importing of :func:`multiprocessing.set_forkserver_preload` modules in the
same manner as :data:`sys.path` is configured when executing work items in
the worker processes.

This bug could cause some forkserver module preloading to silently fail to
be preloaded, leading to a performance degration in child processes due to
additional repeated work. It could also have led to a side effect of ``""``
still being in :data:`sys.path` during forkserver preload imports instead of
the absolute path of the directory that workers see.
Loading