|
37 | 37 | WindowFrame, |
38 | 38 | column, |
39 | 39 | literal, |
| 40 | + udf, |
40 | 41 | ) |
41 | 42 | from datafusion import ( |
42 | 43 | col as df_col, |
@@ -3190,179 +3191,42 @@ def test_fill_null_all_null_column(ctx): |
3190 | 3191 | assert result.column(1).to_pylist() == ["filled", "filled", "filled"] |
3191 | 3192 |
|
3192 | 3193 |
|
3193 | | -def test_collect_interrupted(): |
3194 | | - """Test that a long-running query can be interrupted with Ctrl-C. |
| 3194 | +@udf([pa.int64()], pa.int64(), "immutable") |
| 3195 | +def slow_udf(x: pa.Array) -> pa.Array: |
| 3196 | + # This must be longer than the check interval in wait_for_future |
| 3197 | + time.sleep(2.0) |
| 3198 | + return x |
3195 | 3199 |
|
3196 | | - This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt |
3197 | | - exception in the main thread during a long-running query execution. |
3198 | | - """ |
3199 | | - # Create a context and a DataFrame with a query that will run for a while |
3200 | | - ctx = SessionContext() |
3201 | | - |
3202 | | - # Create a recursive computation that will run for some time |
3203 | | - batches = [] |
3204 | | - for i in range(10): |
3205 | | - batch = pa.RecordBatch.from_arrays( |
3206 | | - [ |
3207 | | - pa.array(list(range(i * 1000, (i + 1) * 1000))), |
3208 | | - pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), |
3209 | | - ], |
3210 | | - names=["a", "b"], |
3211 | | - ) |
3212 | | - batches.append(batch) |
3213 | | - |
3214 | | - # Register tables |
3215 | | - ctx.register_record_batches("t1", [batches]) |
3216 | | - ctx.register_record_batches("t2", [batches]) |
3217 | | - |
3218 | | - # Create a large join operation that will take time to process |
3219 | | - df = ctx.sql(""" |
3220 | | - WITH t1_expanded AS ( |
3221 | | - SELECT |
3222 | | - a, |
3223 | | - b, |
3224 | | - CAST(a AS DOUBLE) / 1.5 AS c, |
3225 | | - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d |
3226 | | - FROM t1 |
3227 | | - CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) |
3228 | | - ), |
3229 | | - t2_expanded AS ( |
3230 | | - SELECT |
3231 | | - a, |
3232 | | - b, |
3233 | | - CAST(a AS DOUBLE) * 2.5 AS e, |
3234 | | - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f |
3235 | | - FROM t2 |
3236 | | - CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) |
3237 | | - ) |
3238 | | - SELECT |
3239 | | - t1.a, t1.b, t1.c, t1.d, |
3240 | | - t2.a AS a2, t2.b AS b2, t2.e, t2.f |
3241 | | - FROM t1_expanded t1 |
3242 | | - JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 |
3243 | | - WHERE t1.a > 100 AND t2.a > 100 |
3244 | | - """) |
3245 | | - |
3246 | | - # Flag to track if the query was interrupted |
3247 | | - interrupted = False |
3248 | | - interrupt_error = None |
3249 | | - main_thread = threading.main_thread() |
3250 | | - |
3251 | | - # Shared flag to indicate query execution has started |
3252 | | - query_started = threading.Event() |
3253 | | - max_wait_time = 5.0 # Maximum wait time in seconds |
3254 | | - |
3255 | | - # This function will be run in a separate thread and will raise |
3256 | | - # KeyboardInterrupt in the main thread |
3257 | | - def trigger_interrupt(): |
3258 | | - """Poll for query start, then raise KeyboardInterrupt in the main thread""" |
3259 | | - # Poll for query to start with small sleep intervals |
3260 | | - start_time = time.time() |
3261 | | - while not query_started.is_set(): |
3262 | | - time.sleep(0.1) # Small sleep between checks |
3263 | | - if time.time() - start_time > max_wait_time: |
3264 | | - msg = f"Query did not start within {max_wait_time} seconds" |
3265 | | - raise RuntimeError(msg) |
3266 | | - |
3267 | | - # Check if thread ID is available |
3268 | | - thread_id = main_thread.ident |
3269 | | - if thread_id is None: |
3270 | | - msg = "Cannot get main thread ID" |
3271 | | - raise RuntimeError(msg) |
3272 | | - |
3273 | | - # Use ctypes to raise exception in main thread |
3274 | | - exception = ctypes.py_object(KeyboardInterrupt) |
3275 | | - res = ctypes.pythonapi.PyThreadState_SetAsyncExc( |
3276 | | - ctypes.c_long(thread_id), exception |
3277 | | - ) |
3278 | | - if res != 1: |
3279 | | - # If res is 0, the thread ID was invalid |
3280 | | - # If res > 1, we modified multiple threads |
3281 | | - ctypes.pythonapi.PyThreadState_SetAsyncExc( |
3282 | | - ctypes.c_long(thread_id), ctypes.py_object(0) |
3283 | | - ) |
3284 | | - msg = "Failed to raise KeyboardInterrupt in main thread" |
3285 | | - raise RuntimeError(msg) |
3286 | | - |
3287 | | - # Start a thread to trigger the interrupt |
3288 | | - interrupt_thread = threading.Thread(target=trigger_interrupt) |
3289 | | - # we mark as daemon so the test process can exit even if this thread doesn't finish |
3290 | | - interrupt_thread.daemon = True |
3291 | | - interrupt_thread.start() |
3292 | | - |
3293 | | - # Execute the query and expect it to be interrupted |
3294 | | - try: |
3295 | | - # Signal that we're about to start the query |
3296 | | - query_started.set() |
3297 | | - df.collect() |
3298 | | - except KeyboardInterrupt: |
3299 | | - interrupted = True |
3300 | | - except Exception as e: |
3301 | | - interrupt_error = e |
3302 | | - |
3303 | | - # Assert that the query was interrupted properly |
3304 | | - if not interrupted: |
3305 | | - pytest.fail(f"Query was not interrupted; got error: {interrupt_error}") |
3306 | | - |
3307 | | - # Make sure the interrupt thread has finished |
3308 | | - interrupt_thread.join(timeout=1.0) |
3309 | 3200 |
|
| 3201 | +@pytest.mark.parametrize( |
| 3202 | + ("slow_query", "as_c_stream"), |
| 3203 | + [ |
| 3204 | + (True, True), |
| 3205 | + (True, False), |
| 3206 | + (False, True), |
| 3207 | + (False, False), |
| 3208 | + ], |
| 3209 | +) |
| 3210 | +def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 PLR0915 |
| 3211 | + """Ensure collection responds to ``KeyboardInterrupt`` signals. |
3310 | 3212 |
|
3311 | | -def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915 |
3312 | | - """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals. |
 | 3213 | + This test issues a long-running query and consumes the results via |
 | 3214 | + either a ``collect()`` call or ``__arrow_c_stream__``. It raises |
 | 3215 | + ``KeyboardInterrupt`` in the main thread and verifies that the |
 | 3216 | + query execution was interrupted. |
3313 | 3217 |
|
3314 | | - Similar to ``test_collect_interrupted`` this test issues a long running |
3315 | | - query, but consumes the results via ``__arrow_c_stream__``. It then raises |
3316 | | - ``KeyboardInterrupt`` in the main thread and verifies that the stream |
3317 | | - iteration stops promptly with the appropriate exception. |
 | 3218 | + The ``slow_query`` parameter determines whether the query itself is |
 | 3219 | + slow (via a UDF that sleeps) or whether it is a fast query that |
 | 3220 | + generates many results, so iterating through them all takes a long time. |
3318 | 3221 | """ |
3319 | 3222 |
|
3320 | 3223 | ctx = SessionContext() |
| 3224 | + df = ctx.sql("select * from generate_series(1, 1000000000000000000)") |
| 3225 | + if slow_query: |
| 3226 | + df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a"))) |
3321 | 3227 |
|
3322 | | - batches = [] |
3323 | | - for i in range(10): |
3324 | | - batch = pa.RecordBatch.from_arrays( |
3325 | | - [ |
3326 | | - pa.array(list(range(i * 1000, (i + 1) * 1000))), |
3327 | | - pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), |
3328 | | - ], |
3329 | | - names=["a", "b"], |
3330 | | - ) |
3331 | | - batches.append(batch) |
3332 | | - |
3333 | | - ctx.register_record_batches("t1", [batches]) |
3334 | | - ctx.register_record_batches("t2", [batches]) |
3335 | | - |
3336 | | - df = ctx.sql( |
3337 | | - """ |
3338 | | - WITH t1_expanded AS ( |
3339 | | - SELECT |
3340 | | - a, |
3341 | | - b, |
3342 | | - CAST(a AS DOUBLE) / 1.5 AS c, |
3343 | | - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d |
3344 | | - FROM t1 |
3345 | | - CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) |
3346 | | - ), |
3347 | | - t2_expanded AS ( |
3348 | | - SELECT |
3349 | | - a, |
3350 | | - b, |
3351 | | - CAST(a AS DOUBLE) * 2.5 AS e, |
3352 | | - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f |
3353 | | - FROM t2 |
3354 | | - CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) |
3355 | | - ) |
3356 | | - SELECT |
3357 | | - t1.a, t1.b, t1.c, t1.d, |
3358 | | - t2.a AS a2, t2.b AS b2, t2.e, t2.f |
3359 | | - FROM t1_expanded t1 |
3360 | | - JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 |
3361 | | - WHERE t1.a > 100 AND t2.a > 100 |
3362 | | - """ |
3363 | | - ) |
3364 | | - |
3365 | | - reader = pa.RecordBatchReader.from_stream(df) |
| 3228 | + if as_c_stream: |
| 3229 | + reader = pa.RecordBatchReader.from_stream(df) |
3366 | 3230 |
|
3367 | 3231 | read_started = threading.Event() |
3368 | 3232 | read_exception = [] |
@@ -3396,7 +3260,10 @@ def read_stream(): |
3396 | 3260 | read_thread_id = threading.get_ident() |
3397 | 3261 | try: |
3398 | 3262 | read_started.set() |
3399 | | - reader.read_all() |
| 3263 | + if as_c_stream: |
| 3264 | + reader.read_all() |
| 3265 | + else: |
| 3266 | + df.collect() |
3400 | 3267 | # If we get here, the read completed without interruption |
3401 | 3268 | read_exception.append(RuntimeError("Read completed without interruption")) |
3402 | 3269 | except KeyboardInterrupt: |
|
0 commit comments