Skip to content

Commit fd411e5

Browse files
authored
Terminate threads to avoid test interactions (#244)
1 parent c58b673 commit fd411e5

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import sys
2+
import threading
23
from types import ModuleType
34
from unittest.mock import Mock
45

56
import pytest
67
import torch.distributed
8+
from litdata.streaming.reader import PrepareChunksThread
79

810

911
@pytest.fixture(autouse=True)
@@ -65,3 +67,31 @@ def lightning_sdk_mock(monkeypatch):
6567
lightning_sdk = ModuleType("lightning_sdk")
6668
monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk)
6769
return lightning_sdk
70+
71+
72+
@pytest.fixture(autouse=True)
73+
def _thread_police():
74+
"""Attempts to stop left-over threads to avoid test interactions.
75+
76+
Adapted from PyTorch Lightning.
77+
78+
"""
79+
active_threads_before = set(threading.enumerate())
80+
yield
81+
active_threads_after = set(threading.enumerate())
82+
83+
for thread in active_threads_after - active_threads_before:
84+
if isinstance(thread, PrepareChunksThread):
85+
thread.force_stop()
86+
continue
87+
88+
stop = getattr(thread, "stop", None) or getattr(thread, "exit", None)
89+
if thread.daemon and callable(stop):
90+
# A daemon thread would anyway be stopped at the end of a program
91+
# We do it preemptively here to reduce the risk of interactions with other tests that run after
92+
stop()
93+
assert not thread.is_alive()
94+
elif thread.name == "QueueFeederThread":
95+
thread.join(timeout=20)
96+
else:
97+
raise AssertionError(f"Test left zombie thread: {thread}")

0 commit comments

Comments
 (0)