Skip to content

Commit 7aff363

Browse files
authored
Use explicit timer in unit test (#1338)
* Use an explicit wait in a dataframe query during testing to check for keyboard interrupts * Add interrupt check when spawning futures * Update unit test to do four variations of fast/slow queries and interrupt either collect or stream
1 parent a922967 commit 7aff363

File tree

2 files changed

+38
-166
lines changed

2 files changed

+38
-166
lines changed

python/tests/test_dataframe.py

Lines changed: 33 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
WindowFrame,
3838
column,
3939
literal,
40+
udf,
4041
)
4142
from datafusion import (
4243
col as df_col,
@@ -3190,179 +3191,42 @@ def test_fill_null_all_null_column(ctx):
31903191
assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
31913192

31923193

3193-
def test_collect_interrupted():
3194-
"""Test that a long-running query can be interrupted with Ctrl-C.
3194+
@udf([pa.int64()], pa.int64(), "immutable")
3195+
def slow_udf(x: pa.Array) -> pa.Array:
3196+
# This must be longer than the check interval in wait_for_future
3197+
time.sleep(2.0)
3198+
return x
31953199

3196-
This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
3197-
exception in the main thread during a long-running query execution.
3198-
"""
3199-
# Create a context and a DataFrame with a query that will run for a while
3200-
ctx = SessionContext()
3201-
3202-
# Create a recursive computation that will run for some time
3203-
batches = []
3204-
for i in range(10):
3205-
batch = pa.RecordBatch.from_arrays(
3206-
[
3207-
pa.array(list(range(i * 1000, (i + 1) * 1000))),
3208-
pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
3209-
],
3210-
names=["a", "b"],
3211-
)
3212-
batches.append(batch)
3213-
3214-
# Register tables
3215-
ctx.register_record_batches("t1", [batches])
3216-
ctx.register_record_batches("t2", [batches])
3217-
3218-
# Create a large join operation that will take time to process
3219-
df = ctx.sql("""
3220-
WITH t1_expanded AS (
3221-
SELECT
3222-
a,
3223-
b,
3224-
CAST(a AS DOUBLE) / 1.5 AS c,
3225-
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
3226-
FROM t1
3227-
CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
3228-
),
3229-
t2_expanded AS (
3230-
SELECT
3231-
a,
3232-
b,
3233-
CAST(a AS DOUBLE) * 2.5 AS e,
3234-
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
3235-
FROM t2
3236-
CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
3237-
)
3238-
SELECT
3239-
t1.a, t1.b, t1.c, t1.d,
3240-
t2.a AS a2, t2.b AS b2, t2.e, t2.f
3241-
FROM t1_expanded t1
3242-
JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
3243-
WHERE t1.a > 100 AND t2.a > 100
3244-
""")
3245-
3246-
# Flag to track if the query was interrupted
3247-
interrupted = False
3248-
interrupt_error = None
3249-
main_thread = threading.main_thread()
3250-
3251-
# Shared flag to indicate query execution has started
3252-
query_started = threading.Event()
3253-
max_wait_time = 5.0 # Maximum wait time in seconds
3254-
3255-
# This function will be run in a separate thread and will raise
3256-
# KeyboardInterrupt in the main thread
3257-
def trigger_interrupt():
3258-
"""Poll for query start, then raise KeyboardInterrupt in the main thread"""
3259-
# Poll for query to start with small sleep intervals
3260-
start_time = time.time()
3261-
while not query_started.is_set():
3262-
time.sleep(0.1) # Small sleep between checks
3263-
if time.time() - start_time > max_wait_time:
3264-
msg = f"Query did not start within {max_wait_time} seconds"
3265-
raise RuntimeError(msg)
3266-
3267-
# Check if thread ID is available
3268-
thread_id = main_thread.ident
3269-
if thread_id is None:
3270-
msg = "Cannot get main thread ID"
3271-
raise RuntimeError(msg)
3272-
3273-
# Use ctypes to raise exception in main thread
3274-
exception = ctypes.py_object(KeyboardInterrupt)
3275-
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
3276-
ctypes.c_long(thread_id), exception
3277-
)
3278-
if res != 1:
3279-
# If res is 0, the thread ID was invalid
3280-
# If res > 1, we modified multiple threads
3281-
ctypes.pythonapi.PyThreadState_SetAsyncExc(
3282-
ctypes.c_long(thread_id), ctypes.py_object(0)
3283-
)
3284-
msg = "Failed to raise KeyboardInterrupt in main thread"
3285-
raise RuntimeError(msg)
3286-
3287-
# Start a thread to trigger the interrupt
3288-
interrupt_thread = threading.Thread(target=trigger_interrupt)
3289-
# we mark as daemon so the test process can exit even if this thread doesn't finish
3290-
interrupt_thread.daemon = True
3291-
interrupt_thread.start()
3292-
3293-
# Execute the query and expect it to be interrupted
3294-
try:
3295-
# Signal that we're about to start the query
3296-
query_started.set()
3297-
df.collect()
3298-
except KeyboardInterrupt:
3299-
interrupted = True
3300-
except Exception as e:
3301-
interrupt_error = e
3302-
3303-
# Assert that the query was interrupted properly
3304-
if not interrupted:
3305-
pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
3306-
3307-
# Make sure the interrupt thread has finished
3308-
interrupt_thread.join(timeout=1.0)
33093200

3201+
@pytest.mark.parametrize(
3202+
("slow_query", "as_c_stream"),
3203+
[
3204+
(True, True),
3205+
(True, False),
3206+
(False, True),
3207+
(False, False),
3208+
],
3209+
)
3210+
def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 PLR0915
3211+
"""Ensure collection responds to ``KeyboardInterrupt`` signals.
33103212
3311-
def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915
3312-
"""__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
3213+
This test issues a long-running query, and consumes the results via
3214+
either a collect() call or ``__arrow_c_stream__``. It raises
3215+
``KeyboardInterrupt`` in the main thread and verifies that the
3216+
process has been interrupted.
33133217
3314-
Similar to ``test_collect_interrupted`` this test issues a long running
3315-
query, but consumes the results via ``__arrow_c_stream__``. It then raises
3316-
``KeyboardInterrupt`` in the main thread and verifies that the stream
3317-
iteration stops promptly with the appropriate exception.
3218+
The `slow_query` parameter determines whether the query itself is slow via a
3219+
UDF with a timeout or if it is a fast query that generates many
3220+
results so it takes a long time to iterate through them all.
33183221
"""
33193222

33203223
ctx = SessionContext()
3224+
df = ctx.sql("select * from generate_series(1, 1000000000000000000)")
3225+
if slow_query:
3226+
df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))
33213227

3322-
batches = []
3323-
for i in range(10):
3324-
batch = pa.RecordBatch.from_arrays(
3325-
[
3326-
pa.array(list(range(i * 1000, (i + 1) * 1000))),
3327-
pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
3328-
],
3329-
names=["a", "b"],
3330-
)
3331-
batches.append(batch)
3332-
3333-
ctx.register_record_batches("t1", [batches])
3334-
ctx.register_record_batches("t2", [batches])
3335-
3336-
df = ctx.sql(
3337-
"""
3338-
WITH t1_expanded AS (
3339-
SELECT
3340-
a,
3341-
b,
3342-
CAST(a AS DOUBLE) / 1.5 AS c,
3343-
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
3344-
FROM t1
3345-
CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
3346-
),
3347-
t2_expanded AS (
3348-
SELECT
3349-
a,
3350-
b,
3351-
CAST(a AS DOUBLE) * 2.5 AS e,
3352-
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
3353-
FROM t2
3354-
CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
3355-
)
3356-
SELECT
3357-
t1.a, t1.b, t1.c, t1.d,
3358-
t2.a AS a2, t2.b AS b2, t2.e, t2.f
3359-
FROM t1_expanded t1
3360-
JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
3361-
WHERE t1.a > 100 AND t2.a > 100
3362-
"""
3363-
)
3364-
3365-
reader = pa.RecordBatchReader.from_stream(df)
3228+
if as_c_stream:
3229+
reader = pa.RecordBatchReader.from_stream(df)
33663230

33673231
read_started = threading.Event()
33683232
read_exception = []
@@ -3396,7 +3260,10 @@ def read_stream():
33963260
read_thread_id = threading.get_ident()
33973261
try:
33983262
read_started.set()
3399-
reader.read_all()
3263+
if as_c_stream:
3264+
reader.read_all()
3265+
else:
3266+
df.collect()
34003267
# If we get here, the read completed without interruption
34013268
read_exception.append(RuntimeError("Read completed without interruption"))
34023269
except KeyboardInterrupt:

src/utils.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ where
7777
let runtime: &Runtime = &get_tokio_runtime().0;
7878
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
7979

80+
// Some fast running processes that generate many `wait_for_future` calls like
81+
// PartitionedDataFrameStreamReader::next require checking for interrupts early
82+
py.run(cr"pass", None, None)?;
83+
py.check_signals()?;
84+
8085
py.detach(|| {
8186
runtime.block_on(async {
8287
tokio::pin!(fut);

0 commit comments

Comments
 (0)