Skip to content

Commit 7dbbd24

Browse files
authored
Stop HistorySavingThread before fork (ipython#15115)
Python 3.12+ issues a DeprecationWarning if `os.fork()` is called while there are multiple threads, and in fact this is not strictly safe on any version of Python, although on Linux with glibc it mostly works well enough that we don't notice. The man page for `fork()` says: > After a fork() in a multithreaded program, the child can safely call only async-signal-safe functions (see signal-safety(7)) until such time as it calls execve(2). This is incompatible with executing any Python code, hence the warning. IPython is always multi-threaded because of HistorySavingThread, so it's never safe to call `os.fork()` or use multiprocessing's fork context. However, I think we can easily resolve this by stopping the history saving thread before fork, and starting it again after, so at the moment of fork it's a single-threaded process. In principle, I think it's safe to start it again in the `after_in_parent` handler, but Python's check for threads to issue the warning runs after that callback, so to avoid triggering the warning, it's better to start it again when there's something for it to do. This unfortunately only solves the issue for IPython in the terminal, since IPykernel uses several more threads. But fixing it in the terminal is already useful, and this would also be one small step towards fixing it in IPykernel. ----- Testing: ```python import os import time def func_with_fork(): if (pid := os.fork()) > 0: print("Parent") t = os.waitpid(pid, 0) print("Parent: child finished", t) else: print("Child") time.sleep(0.5) print("Child finishing") os._exit(0) if __name__ == "__main__": func_with_fork() ``` ```shell PYTHONWARNINGS=default ipython ```
2 parents 28d9b9c + e602694 commit 7dbbd24

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

IPython/core/history.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import atexit
1010
import datetime
11+
import os
1112
import re
1213

1314

@@ -689,6 +690,7 @@ def __init__(
689690
)
690691
self.hist_file = ":memory:"
691692

693+
self.using_thread = False
692694
if self.enabled and self.hist_file != ":memory:":
693695
self.save_thread = HistorySavingThread(self)
694696
try:
@@ -699,6 +701,8 @@ def __init__(
699701
exc_info=True,
700702
)
701703
self.hist_file = ":memory:"
704+
else:
705+
self.using_thread = True
702706
self._instances.add(self)
703707
assert len(HistoryManager._instances) <= HistoryManager._max_inst, (
704708
len(HistoryManager._instances),
@@ -709,6 +713,20 @@ def __del__(self) -> None:
709713
if self.save_thread is not None:
710714
self.save_thread.stop()
711715

716+
@classmethod
717+
def _stop_thread(cls) -> None:
718+
# Used before forking so the thread isn't running at fork
719+
for inst in cls._instances:
720+
if inst.save_thread is not None:
721+
inst.save_thread.stop()
722+
inst.save_thread = None
723+
724+
def _restart_thread_if_stopped(self) -> None:
725+
# Start the thread again after it was stopped for forking
726+
if self.save_thread is None and self.using_thread:
727+
self.save_thread = HistorySavingThread(self)
728+
self.save_thread.start()
729+
712730
def _get_hist_file_name(self, profile: Optional[str] = None) -> Path:
713731
"""Get default history file name based on the Shell's profile.
714732
@@ -970,8 +988,10 @@ def store_inputs(
970988
self.db_input_cache.append((line_num, source, source_raw))
971989
# Trigger to flush cache and write to DB.
972990
if len(self.db_input_cache) >= self.db_cache_size:
973-
if self.save_flag:
974-
self.save_flag.set()
991+
if self.using_thread:
992+
self._restart_thread_if_stopped()
993+
if self.save_flag is not None:
994+
self.save_flag.set()
975995

976996
# update the auto _i variables
977997
self._iii = self._ii
@@ -1003,8 +1023,10 @@ def store_output(self, line_num: int) -> None:
10031023

10041024
with self.db_output_cache_lock:
10051025
self.db_output_cache.append((line_num, output))
1006-
if self.db_cache_size <= 1 and self.save_flag is not None:
1007-
self.save_flag.set()
1026+
if self.db_cache_size <= 1 and self.using_thread:
1027+
self._restart_thread_if_stopped()
1028+
if self.save_flag is not None:
1029+
self.save_flag.set()
10081030

10091031
def _writeout_input_cache(self, conn: sqlite3.Connection) -> None:
10101032
with conn:
@@ -1059,6 +1081,10 @@ def writeout_cache(self, conn: Optional[sqlite3.Connection] = None) -> None:
10591081
self.db_output_cache = []
10601082

10611083

1084+
if hasattr(os, "register_at_fork"):
1085+
os.register_at_fork(before=HistoryManager._stop_thread)
1086+
1087+
10621088
from collections.abc import Callable, Iterator
10631089
from weakref import ReferenceType
10641090

0 commit comments

Comments
 (0)