Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-117378: Fix multiprocessing forkserver preload sys.path inheritance. #126538

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Lib/multiprocessing/forkserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def ensure_running(self):
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
'''Run forkserver.'''
if preload:
if sys_path is not None:
sys.path[:] = sys_path
if '__main__' in preload and main_path is not None:
process.current_process()._inheriting = True
try:
Expand Down
78 changes: 78 additions & 0 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
import os
import gc
import importlib
import errno
import functools
import signal
Expand All @@ -20,8 +21,10 @@
import socket
import random
import logging
import shutil
import subprocess
import struct
import tempfile
import operator
import pickle
import weakref
Expand Down Expand Up @@ -6275,6 +6278,81 @@ def test_atexit(self):
self.assertEqual(f.read(), 'deadbeef')


class _TestSpawnedSysPath(BaseTestCase):
"""Test that sys.path is setup in forkserver and spawn processes."""

ALLOWED_TYPES = ('processes',)

def setUp(self):
self._orig_sys_path = list(sys.path)
self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-")
gpshead marked this conversation as resolved.
Show resolved Hide resolved
self._mod_name = "unique_test_mod"
module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py")
with open(module_path, "w", encoding="utf-8") as mod:
mod.write("# A simple test module\n")
sys.path[:] = [p for p in sys.path if p] # remove any existing ""s
sys.path.insert(0, self._temp_dir)
sys.path.insert(0, "") # Replaced with an abspath in child.
gpshead marked this conversation as resolved.
Show resolved Hide resolved
try:
self._ctx_forkserver = multiprocessing.get_context("forkserver")
except ValueError:
self._ctx_forkserver = None
self._ctx_spawn = multiprocessing.get_context("spawn")

def tearDown(self):
sys.path[:] = self._orig_sys_path
shutil.rmtree(self._temp_dir, ignore_errors=True)

@staticmethod
def enq_imported_module_names(queue):
queue.put(tuple(sys.modules))

def test_forkserver_preload_imports_sys_path(self):
ctx = self._ctx_forkserver
if not ctx:
self.skipTest("requires forkserver start method.")
self.assertNotIn(self._mod_name, sys.modules)
multiprocessing.forkserver._forkserver._stop() # Must be fresh.
ctx.set_forkserver_preload(
["test.test_multiprocessing_forkserver", self._mod_name])
q = ctx.Queue()
proc = ctx.Process(target=self.enq_imported_module_names, args=(q,))
proc.start()
proc.join()
child_imported_modules = q.get()
q.close()
self.assertIn(self._mod_name, child_imported_modules)

@staticmethod
def enq_sys_path_and_import(queue, mod_name):
queue.put(sys.path)
try:
importlib.import_module(mod_name)
except ImportError as exc:
queue.put(exc)
else:
queue.put(None)

def test_child_sys_path(self):
for ctx in (self._ctx_spawn, self._ctx_forkserver):
if not ctx:
continue
gpshead marked this conversation as resolved.
Show resolved Hide resolved
with self.subTest(f"{ctx.get_start_method()} start method"):
q = ctx.Queue()
proc = ctx.Process(target=self.enq_sys_path_and_import,
args=(q, self._mod_name))
proc.start()
proc.join()
child_sys_path = q.get()
import_error = q.get()
q.close()
self.assertNotIn("", child_sys_path) # replaced by an abspath
self.assertIn(self._temp_dir, child_sys_path) # our addition
# ignore the first element, it is the absolute "" replacement
self.assertEqual(child_sys_path[1:], sys.path[1:])
self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}")


class MiscTestCase(unittest.TestCase):
def test__all__(self):
# Just make sure names in not_exported are excluded
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Fixed the :mod:`multiprocessing` ``"forkserver"`` start method forkserver
process to correctly inherit the parent's :data:`sys.path` during the importing
of :func:`multiprocessing.set_forkserver_preload` modules in the same manner as
:data:`sys.path` is configured in workers before executing work items.

This bug caused some forkserver module preloading to silently fail to preload.
This manifested as a performance degration in child processes when the
``sys.path`` was required due to additional repeated work in every worker.

It could also have a side effect of ``""`` remaining in :data:`sys.path` during
forkserver preload imports instead of the absolute path from :func:`os.getcwd`
at multiprocessing import time used in the worker ``sys.path``.

Potentially leading to incorrect imports from the wrong location during
preload. We are unaware of that actually happening. The issue was discovered
by someone observing unexpected preload performance gains.

Loading