|
1 | 1 | import sys
|
| 2 | +import threading |
2 | 3 | from types import ModuleType
|
3 | 4 | from unittest.mock import Mock
|
4 | 5 |
|
5 | 6 | import pytest
|
6 | 7 | import torch.distributed
|
| 8 | +from litdata.streaming.reader import PrepareChunksThread |
7 | 9 |
|
8 | 10 |
|
9 | 11 | @pytest.fixture(autouse=True)
|
@@ -65,3 +67,31 @@ def lightning_sdk_mock(monkeypatch):
|
65 | 67 | lightning_sdk = ModuleType("lightning_sdk")
|
66 | 68 | monkeypatch.setitem(sys.modules, "lightning_sdk", lightning_sdk)
|
67 | 69 | return lightning_sdk
|
| 70 | + |
| 71 | + |
| 72 | +@pytest.fixture(autouse=True) |
| 73 | +def _thread_police(): |
| 74 | + """Attempts to stop left-over threads to avoid test interactions. |
| 75 | +
|
| 76 | + Adapted from PyTorch Lightning. |
| 77 | +
|
| 78 | + """ |
| 79 | + active_threads_before = set(threading.enumerate()) |
| 80 | + yield |
| 81 | + active_threads_after = set(threading.enumerate()) |
| 82 | + |
| 83 | + for thread in active_threads_after - active_threads_before: |
| 84 | + if isinstance(thread, PrepareChunksThread): |
| 85 | + thread.force_stop() |
| 86 | + continue |
| 87 | + |
| 88 | + stop = getattr(thread, "stop", None) or getattr(thread, "exit", None) |
| 89 | + if thread.daemon and callable(stop): |
| 90 | + # A daemon thread would anyway be stopped at the end of a program |
| 91 | + # We do it preemptively here to reduce the risk of interactions with other tests that run after |
| 92 | + stop() |
| 93 | + assert not thread.is_alive() |
| 94 | + elif thread.name == "QueueFeederThread": |
| 95 | + thread.join(timeout=20) |
| 96 | + else: |
| 97 | + raise AssertionError(f"Test left zombie thread: {thread}") |
0 commit comments