|
12 | 12 | import sys
|
13 | 13 | import os
|
14 | 14 | import gc
|
| 15 | +import importlib |
15 | 16 | import errno
|
16 | 17 | import functools
|
17 | 18 | import signal
|
|
20 | 21 | import socket
|
21 | 22 | import random
|
22 | 23 | import logging
|
| 24 | +import shutil |
23 | 25 | import subprocess
|
24 | 26 | import struct
|
| 27 | +import tempfile |
25 | 28 | import operator
|
26 | 29 | import pickle
|
27 | 30 | import weakref
|
@@ -6275,6 +6278,80 @@ def test_atexit(self):
|
6275 | 6278 | self.assertEqual(f.read(), 'deadbeef')
|
6276 | 6279 |
|
6277 | 6280 |
|
| 6281 | +class _TestSpawnedSysPath(BaseTestCase): |
| 6282 | + """Test that sys.path is setup in forkserver and spawn processes.""" |
| 6283 | + |
| 6284 | + ALLOWED_TYPES = ('processes',) |
| 6285 | + |
| 6286 | + def setUp(self): |
| 6287 | + self._orig_sys_path = list(sys.path) |
| 6288 | + self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-") |
| 6289 | + self._mod_name = "unique_test_mod" |
| 6290 | + module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py") |
| 6291 | + with open(module_path, "w", encoding="utf-8") as mod: |
| 6292 | + mod.write("# A simple test module\n") |
| 6293 | + sys.path[:] = [p for p in sys.path if p] # remove any existing ""s |
| 6294 | + sys.path.insert(0, self._temp_dir) |
| 6295 | + sys.path.insert(0, "") # Replaced with an abspath in child. |
| 6296 | + try: |
| 6297 | + self._ctx_forkserver = multiprocessing.get_context("forkserver") |
| 6298 | + except ValueError: |
| 6299 | + self._ctx_forkserver = None |
| 6300 | + self._ctx_spawn = multiprocessing.get_context("spawn") |
| 6301 | + |
| 6302 | + def tearDown(self): |
| 6303 | + sys.path[:] = self._orig_sys_path |
| 6304 | + shutil.rmtree(self._temp_dir, ignore_errors=True) |
| 6305 | + |
| 6306 | + @staticmethod |
| 6307 | + def enq_imported_module_names(queue): |
| 6308 | + queue.put(tuple(sys.modules)) |
| 6309 | + |
| 6310 | + def test_forkserver_preload_imports_sys_path(self): |
| 6311 | + if not (ctx := self._ctx_forkserver): |
| 6312 | + self.skipTest("requires forkserver start method.") |
| 6313 | + self.assertNotIn(self._mod_name, sys.modules) |
| 6314 | + multiprocessing.forkserver._forkserver._stop() # Must be fresh. |
| 6315 | + ctx.set_forkserver_preload( |
| 6316 | + ["test.test_multiprocessing_forkserver", self._mod_name]) |
| 6317 | + q = ctx.Queue() |
| 6318 | + proc = ctx.Process(target=self.enq_imported_module_names, args=(q,)) |
| 6319 | + proc.start() |
| 6320 | + proc.join() |
| 6321 | + child_imported_modules = q.get() |
| 6322 | + q.close() |
| 6323 | + self.assertIn(self._mod_name, child_imported_modules) |
| 6324 | + |
| 6325 | + @staticmethod |
| 6326 | + def enq_sys_path_and_import(queue, mod_name): |
| 6327 | + queue.put(sys.path) |
| 6328 | + try: |
| 6329 | + importlib.import_module(mod_name) |
| 6330 | + except ImportError as exc: |
| 6331 | + queue.put(exc) |
| 6332 | + else: |
| 6333 | + queue.put(None) |
| 6334 | + |
| 6335 | + def test_child_sys_path(self): |
| 6336 | + for ctx in (self._ctx_spawn, self._ctx_forkserver): |
| 6337 | + if not ctx: |
| 6338 | + continue |
| 6339 | + with self.subTest(f"{ctx.get_start_method()} start method"): |
| 6340 | + q = ctx.Queue() |
| 6341 | + proc = ctx.Process(target=self.enq_sys_path_and_import, |
| 6342 | + args=(q, self._mod_name)) |
| 6343 | + proc.start() |
| 6344 | + proc.join() |
| 6345 | + child_sys_path = q.get() |
| 6346 | + import_error = q.get() |
| 6347 | + q.close() |
| 6348 | + self.assertNotIn("", child_sys_path) # replaced by an abspath |
| 6349 | + self.assertIn(self._temp_dir, child_sys_path) # our addition |
| 6350 | + # ignore the first element, it is the absolute "" replacement |
| 6351 | + self.assertEqual(child_sys_path[1:], sys.path[1:]) |
| 6352 | + self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") |
| 6353 | + |
| 6354 | + |
6278 | 6355 | class MiscTestCase(unittest.TestCase):
|
6279 | 6356 | def test__all__(self):
|
6280 | 6357 | # Just make sure names in not_exported are excluded
|
|
0 commit comments