Skip to content

Commit

Permalink
if the current_task().coro.cr_frame is in the stack ki_protection_ena…
Browse files Browse the repository at this point in the history
…bled is current_task()._ki_protected
  • Loading branch information
graingert committed Oct 15, 2024
1 parent 408d1ae commit 3593742
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 34 deletions.
13 changes: 13 additions & 0 deletions src/trio/_core/_ki.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import attrs

from .._util import is_main_thread
from ._run_context import GLOBAL_RUN_CONTEXT

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -170,6 +171,16 @@ def legacy_isasyncgenfunction(
# NB: according to the signal.signal docs, 'frame' can be None on entry to
# this function:
def ki_protection_enabled(frame: types.FrameType | None) -> bool:
try:
task = GLOBAL_RUN_CONTEXT.task
except AttributeError:
task_ki_protected = False
task_frame = None
else:
task_ki_protected = task._ki_protected
task_frame = task.coro.cr_frame
del task

while frame is not None:
try:
v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code]
Expand All @@ -179,6 +190,8 @@ def ki_protection_enabled(frame: types.FrameType | None) -> bool:
return bool(v)
if frame.f_code.co_name == "__del__":
return True
if frame is task_frame:
return task_ki_protected
frame = frame.f_back
return True

Expand Down
38 changes: 4 additions & 34 deletions src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import random
import select
import sys
import threading
import warnings
from collections import deque
from contextlib import AbstractAsyncContextManager, contextmanager, suppress
Expand Down Expand Up @@ -39,8 +38,9 @@
from ._entry_queue import EntryQueue, TrioToken
from ._exceptions import Cancelled, RunFinishedError, TrioInternalError
from ._instrumentation import Instruments
from ._ki import KIManager, disable_ki_protection, enable_ki_protection
from ._ki import KIManager, enable_ki_protection
from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER
from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT
from ._thread_cache import start_thread_soon
from ._traps import (
Abort,
Expand Down Expand Up @@ -1559,14 +1559,6 @@ def raise_cancel() -> NoReturn:
################################################################


class RunContext(threading.local):
runner: Runner
task: Task


GLOBAL_RUN_CONTEXT: Final = RunContext()


@attrs.frozen
class RunStatistics:
"""An object containing run-loop-level debugging information.
Expand Down Expand Up @@ -1670,22 +1662,6 @@ def in_main_thread() -> None:
start_thread_soon(get_events, deliver)


@enable_ki_protection
def run_with_ki_protection_enabled(f: Callable[[T], RetT], v: T) -> RetT:
try:
return f(v)
finally:
del v # for the case where f is coro.throw() and v is a (Base)Exception


@disable_ki_protection
def run_with_ki_protection_disabled(f: Callable[[T], RetT], v: T) -> RetT:
try:
return f(v)
finally:
del v # for the case where f is coro.throw() and v is a (Base)Exception


@attrs.define(eq=False)
class Runner:
clock: Clock
Expand Down Expand Up @@ -2730,11 +2706,6 @@ def unrolled_run(

next_send_fn = task._next_send_fn
next_send = task._next_send
run_with = (
run_with_ki_protection_enabled
if task._ki_protected
else run_with_ki_protection_disabled
)
task._next_send_fn = task._next_send = None
final_outcome: Outcome[Any] | None = None
try:
Expand All @@ -2747,17 +2718,16 @@ def unrolled_run(
# https://github.com/python/cpython/issues/108668
# So now we send in the Outcome object and unwrap it on the
# other side.
msg = task.context.run(run_with, next_send_fn, next_send)
msg = task.context.run(next_send_fn, next_send)
except StopIteration as stop_iteration:
final_outcome = Value(stop_iteration.value)
except BaseException as task_exc:
# Store for later, removing uninteresting top frames: 1
# frame we always remove, because it's this function
# another is the run_with
# catching it, and then in addition we remove however many
# more Context.run adds.
tb = task_exc.__traceback__
for _ in range(2 + CONTEXT_RUN_TB_FRAMES):
for _ in range(1 + CONTEXT_RUN_TB_FRAMES):
if tb is not None: # pragma: no branch
tb = tb.tb_next
final_outcome = Error(task_exc.with_traceback(tb))
Expand Down
15 changes: 15 additions & 0 deletions src/trio/_core/_run_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING, Final

if TYPE_CHECKING:
from ._run import Runner, Task


class RunContext(threading.local):
runner: Runner
task: Task


GLOBAL_RUN_CONTEXT: Final = RunContext()

0 comments on commit 3593742

Please sign in to comment.