Skip to content
4 changes: 3 additions & 1 deletion dlt/common/runners/pool_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,12 @@ def _run_func() -> bool:
try:
logger.debug("Running pool")
while _run_func():
# for next run
# raise on signal: safe to do that out of _run_func()
signals.raise_if_signalled()
runs_count += 1
sleep(config.run_sleep)
# signal could come
signals.raise_if_signalled()
return runs_count
except SignalReceivedException as sigex:
# sleep this may raise SignalReceivedException
Expand Down
5 changes: 5 additions & 0 deletions dlt/common/runtime/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class DictCollector(Collector):
"""A collector that just counts"""

def __init__(self) -> None:
self.step = None
self.counters: DefaultDict[str, int] = None

def update(
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
log_level (str, optional): Log level for the logger. Defaults to INFO level
dump_system_stats (bool, optional): Log memory and cpu usage. Defaults to True
"""
self.step = None
self.log_period = log_period
self.logger = logger
self.log_level = log_level
Expand Down Expand Up @@ -260,6 +262,7 @@ def __init__(self, single_bar: bool = False, **tqdm_kwargs: Any) -> None:
raise MissingDependencyException(
"TqdmCollector", ["tqdm"], "We need tqdm to display progress bars."
)
self.step = None
self.single_bar = single_bar
self._bars: Dict[str, tqdm[None]] = {}
self.tqdm_kwargs = tqdm_kwargs or {}
Expand Down Expand Up @@ -321,6 +324,7 @@ def __init__(self, single_bar: bool = True, **alive_kwargs: Any) -> None:
["alive-progress"],
"We need alive-progress to display progress bars.",
)
self.step = None
self.single_bar = single_bar
self._bars: Dict[str, Any] = {}
self._bars_counts: Dict[str, int] = {}
Expand Down Expand Up @@ -399,6 +403,7 @@ def __init__(self, single_bar: bool = False, **enlighten_kwargs: Any) -> None:
["enlighten"],
"We need enlighten to display progress bars with a space for log messages.",
)
self.step = None
self.single_bar = single_bar
self.enlighten_kwargs = enlighten_kwargs

Expand Down
109 changes: 88 additions & 21 deletions dlt/common/runtime/signals.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,74 @@
import sys
import threading
import signal
from contextlib import contextmanager
from threading import Event
from typing import Any, Iterator
from types import FrameType
from typing import Any, Callable, Dict, Iterator, Optional, Union

from dlt.common import logger
from dlt.common.exceptions import SignalReceivedException

# number of the first signal received while handlers are installed; 0 means "no signal".
# Set by signal_receiver() and reset by _clear_signals().
_received_signal: int = 0
# event that sleep() waits on; set by signal_receiver() and wake_all() to wake sleeping threads
exit_event = Event()
# how many times each signal number was received since the handlers were installed
_signal_counts: Dict[int, int] = {}
# handlers that were active before delayed_signals() replaced them, keyed by signal number;
# signal_receiver() delegates to them when the same signal arrives a second time
_original_handlers: Dict[int, Union[int, Callable[[int, Optional[FrameType]], Any]]] = {}


def signal_receiver(sig: int, frame: Optional[FrameType]) -> None:
    """Handle POSIX signals with two-stage escalation.

    This handler is installed by delayed_signals(). On the first occurrence of a
    supported signal (eg. SIGINT, SIGTERM) it requests a graceful shutdown by
    setting a process-wide flag and waking sleeping threads via exit_event.
    A second occurrence of the same signal escalates by delegating to the
    original handler or the system default, which typically results in an
    immediate process termination (eg. KeyboardInterrupt for SIGINT).

    Args:
        sig: Signal number (for example, signal.SIGINT or signal.SIGTERM).
        frame: The current stack frame when the signal was received (may be None).

    Notes:
        - The CPython runtime delivers signal handlers in the main thread only.
          Worker threads must cooperatively observe shutdown via raise_if_signalled();
          the signal-aware sleep() only wakes them up.
    """
    global _received_signal

    # track how many times this signal type has been received
    occurrences = _signal_counts.get(sig, 0) + 1
    _signal_counts[sig] = occurrences

    if occurrences == 1:
        # first signal of this type: set flag and tell the user how to force stop
        _received_signal = sig
        sig_desc = "CTRL-C" if sig == signal.SIGINT else f"Signal {sig}"
        msg = (
            f"{sig_desc} received. Trying to shut down gracefully. It may take time to drain job"
            f" pools. Send {sig_desc} again to force stop."
        )
        # stdin may be None (no console attached) or already closed; a signal handler
        # must not fail on that, so treat both cases as "not interactive"
        try:
            interactive = sys.stdin is not None and sys.stdin.isatty()
        except ValueError:
            interactive = False
        if interactive and sys.stderr is not None:
            # log to console, terminate the line so following output is not glued to it
            sys.stderr.write(msg + "\n")
            sys.stderr.flush()
        else:
            logger.warning(msg)

    # wake sleeping threads BEFORE escalating: the original handler may raise
    # (eg. KeyboardInterrupt) or terminate the process and would skip this step
    exit_event.set()
    logger.debug("Sleeping threads signalled")

    if occurrences >= 2:
        # second signal of this type: call original handler
        logger.debug(f"Second signal {sig} received, calling default handler")
        original_handler = _original_handlers.get(sig, signal.SIG_DFL)
        if callable(original_handler):
            original_handler(sig, frame)
        elif original_handler == signal.SIG_DFL:
            # restore default and re-raise to trigger default behavior
            signal.signal(sig, signal.SIG_DFL)
            signal.raise_signal(sig)


def raise_if_signalled() -> None:
Expand All @@ -38,38 +82,61 @@ def signal_received() -> bool:


def sleep(sleep_seconds: float) -> None:
    """A signal-aware version of sleep function. Will wake up if signal is received but will not raise exception.

    Callers that want to abort on a signal must call `raise_if_signalled` themselves
    after this function returns.

    Args:
        sleep_seconds: Maximum time to block, in seconds.
    """
    # reset the event first: it stays set after an earlier signal or wake_all() and
    # would otherwise make every following sleep return immediately (busy loop)
    exit_event.clear()
    # sleep or wait for signal
    exit_event.wait(sleep_seconds)


def wake_all() -> None:
    """Set the shared exit event so that every thread waiting on it resumes."""
    exit_event.set()


def _clear_signals() -> None:
    """Reset module-level signal state: received flag, per-signal counters and saved handlers."""
    global _received_signal

    # forget per-signal counters and the handlers saved by delayed_signals()
    for registry in (_signal_counts, _original_handlers):
        registry.clear()
    # 0 means that no signal is pending
    _received_signal = 0


@contextmanager
def delayed_signals() -> Iterator[None]:
    """Will delay signalling until `raise_if_signalled` is explicitly used or when
    a second signal with the same int value arrives.

    Installs `signal_receiver` for SIGINT and SIGTERM on entry and restores the
    previous handlers and clears the module signal state on exit.

    A no-op when not called on main thread.

    Can be nested - nested calls are no-ops.
    """
    if threading.current_thread() is not threading.main_thread():
        # handlers can only be installed from the main thread
        logger.info("Running in daemon thread, signals not enabled")
        yield
        return

    # check if handlers are already installed (nested call)
    current_sigint_handler = signal.getsignal(signal.SIGINT)
    if current_sigint_handler is signal_receiver:
        # already installed, this is a nested call - just yield
        yield
        return

    # first call - remember what to restore on exit
    previous_handlers = {
        signal.SIGINT: current_sigint_handler,
        signal.SIGTERM: signal.getsignal(signal.SIGTERM),
    }
    # store original handlers for signal_receiver to use on the second signal
    _original_handlers.update(previous_handlers)

    try:
        signal.signal(signal.SIGINT, signal_receiver)
        signal.signal(signal.SIGTERM, signal_receiver)
        yield
    finally:
        for signum, handler in previous_handlers.items():
            # getsignal() returns None for handlers that were not installed from Python;
            # signal.signal() rejects None with TypeError, so fall back to the default
            signal.signal(signum, signal.SIG_DFL if handler is None else handler)
        _clear_signals()
22 changes: 11 additions & 11 deletions dlt/destinations/job_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,19 @@ def __init__(

def run(self) -> None:
# update filepath, it will be in running jobs now
try:
if self._config.batch_size == 0:
# on batch size zero we only call the callable with the filename
self.call_callable_with_items(self._file_path)
else:
current_index = self._destination_state.get(self._storage_id, 0)
for batch in self.get_batches(current_index):
self.call_callable_with_items(batch)
current_index += len(batch)
self._destination_state[self._storage_id] = current_index
finally:
if self._config.batch_size == 0:
# on batch size zero we only call the callable with the filename
self.call_callable_with_items(self._file_path)
# save progress
commit_load_package_state()
else:
current_index = self._destination_state.get(self._storage_id, 0)
for batch in self.get_batches(current_index):
self.call_callable_with_items(batch)
current_index += len(batch)
self._destination_state[self._storage_id] = current_index
# save progress
commit_load_package_state()

def call_callable_with_items(self, items: TDataItems) -> None:
if not items:
Expand Down
3 changes: 3 additions & 0 deletions dlt/extract/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,10 @@ def _extract_single_source(
delta = left_gens - curr_gens
left_gens -= delta
collector.update("Resources", delta)

# kill extraction if signalled
signals.raise_if_signalled()

resource = source.resources.with_pipe(pipe_item.pipe)
item_format = get_data_item_format(pipe_item.item)
extractors[item_format].write_items(
Expand Down
2 changes: 1 addition & 1 deletion dlt/load/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
if pending_exception:
raise pending_exception
break
# this will raise on signal
# this will wake up on os signal or if job thread pool signals
sleep(self._run_loop_sleep_duration)
except LoadClientJobFailed:
# the package is completed and skipped
Expand Down
1 change: 1 addition & 0 deletions dlt/normalize/items_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def _normalize_chunk(
)
except StopIteration:
pass
# kill job if signalled
signals.raise_if_signalled()
return schema_update

Expand Down
10 changes: 4 additions & 6 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def extract(
# commit load packages with state
extract_step.commit_packages()
return self._get_step_info(extract_step)
except Exception as exc:
except (Exception, KeyboardInterrupt) as exc:
# emit step info
step_info = self._get_step_info(extract_step)
current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None
Expand Down Expand Up @@ -534,7 +534,7 @@ def normalize(self, workers: int = 1) -> NormalizeInfo:
with signals.delayed_signals():
runner.run_pool(normalize_step.config, normalize_step)
return self._get_step_info(normalize_step)
except Exception as n_ex:
except (Exception, KeyboardInterrupt) as n_ex:
step_info = self._get_step_info(normalize_step)
raise PipelineStepFailed(
self,
Expand Down Expand Up @@ -591,7 +591,7 @@ def load(
info: LoadInfo = self._get_step_info(load_step)
self._update_last_run_context()
return info
except Exception as l_ex:
except (Exception, KeyboardInterrupt) as l_ex:
step_info = self._get_step_info(load_step)
raise PipelineStepFailed(
self, "load", load_step.current_load_id, l_ex, step_info
Expand Down Expand Up @@ -684,8 +684,6 @@ def run(
Returns:
LoadInfo: Information on loaded data including the list of package ids and failed job statuses. Please note that `dlt` will not raise if a single job terminally fails. Such information is provided via LoadInfo.
"""

signals.raise_if_signalled()
self.activate()
self._set_destinations(
destination=destination, destination_credentials=credentials, staging=staging
Expand Down Expand Up @@ -866,7 +864,7 @@ def _sync_destination(
state["default_schema_name"] = new_default_schema_name
bump_pipeline_state_version_if_modified(state)
self._save_state(state)
except Exception as ex:
except (Exception, KeyboardInterrupt) as ex:
raise PipelineStepFailed(self, "sync", None, ex, None) from ex

def activate(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions docs/website/docs/reference/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ Please note the following:
process start method.
2. If you created the `Pipeline` object in the worker thread and you use it from another (i.e., the main thread),
call `pipeline.activate()` to inject the right context into the current thread.
3. Note how `with signals.delayed_signals():` was used to enable graceful shutdown of pipelines running in a thread pool.
:::


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ def database_cursor_chunked():
def parallel_pipelines_asyncio_snippet() -> None:
# @@@DLT_SNIPPET_START parallel_pipelines
import asyncio
import dlt
from time import sleep
from concurrent.futures import ThreadPoolExecutor
import dlt
from dlt.common.runtime import signals

# create both asyncio and thread parallel resources
@dlt.resource
Expand Down Expand Up @@ -183,8 +184,11 @@ async def _run_async():
print("pipeline_1", results[0])
print("pipeline_2", results[1])

# load data
asyncio.run(_run_async())
# enable signal handling for graceful shutdowns - it is disabled for pipelines running
# in threads
with signals.delayed_signals():
# load data
asyncio.run(_run_async())
# activate pipelines before they are used
pipeline_1.activate()
assert pipeline_1.last_trace.last_normalize_info.row_counts["async_table"] == 10
Expand Down
Loading