Skip to content

Commit

Permalink
feat(local): interrupt the kernel on execution timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
afsu committed Aug 22, 2024
1 parent 1313e56 commit e7d95bf
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 22 deletions.
65 changes: 43 additions & 22 deletions src/pybox/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,10 @@ def __wait_for_execute_reply(self, msg_id: str, **kwargs) -> ExecutionResponse |
continue
# See <https://jupyter-client.readthedocs.io/en/latest/messaging.html#execution-results>
# error execution may have extra messages, for example a stream std error
response = ExecutionResponse.model_validate(shell_msg)
if response.content.status == "error":
raise CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)
return ExecutionResponse.model_validate(shell_msg)
except queue.Empty:
logger.warning("No shell message received.")
return None
else:
return response
self.__interrupt_kernel()
break

def __get_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
"""Retrieves output from a kernel.
Expand All @@ -70,6 +62,7 @@ def __get_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
CodeExecutionException: if the code execution fails
"""
result = None
error = None
while True:
# Poll the message
try:
Expand All @@ -94,10 +87,19 @@ def __get_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
# See <https://jupyter-client.readthedocs.io/en/stable/messaging.html#streams-stdout-stderr-etc>
if not result:
result = PyBoxOut(data={"text/plain": response.content.text})
elif response.msg_type == "error":
error = CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)

elif response.msg_type == "status": # noqa: SIM102
# According to the document <https://jupyter-client.readthedocs.io/en/latest/messaging.html#request-reply>
# The idle message will be published after processing the request and publishing associated IOPub messages
if response.content.execution_state == "idle":
if error is not None:
raise error
break
except queue.Empty:
logger.warning("No iopub message received.")
Expand All @@ -122,18 +124,10 @@ async def __await_for_execute_reply(self, msg_id: str, **kwargs) -> ExecutionRes
continue
# See <https://jupyter-client.readthedocs.io/en/latest/messaging.html#execution-results>
# error execution may have extra messages, for example a stream std error
response = ExecutionResponse.model_validate(shell_msg)
if response.content.status == "error":
raise CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)
return ExecutionResponse.model_validate(shell_msg)
except queue.Empty:
logger.warning("No shell message received.")
return None
else:
return response
self.__interrupt_kernel()
break

async def __aget_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
"""Retrieves output from a kernel asynchronously.
Expand All @@ -148,6 +142,7 @@ async def __aget_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
CodeExecutionException: if the code execution fails
"""
result = None
error = None
while True:
# Poll the message
try:
Expand All @@ -174,16 +169,42 @@ async def __aget_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
# See <https://jupyter-client.readthedocs.io/en/stable/messaging.html#streams-stdout-stderr-etc>
if not result:
result = PyBoxOut(data={"text/plain": response.content.text})
elif response.msg_type == "error":
error = CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)
elif response.msg_type == "status": # noqa: SIM102
# According to the document <https://jupyter-client.readthedocs.io/en/latest/messaging.html#request-reply>
# The idle message will be published after processing the request and publishing associated IOPub messages
if response.content.execution_state == "idle":
if error is not None:
raise error
break
except queue.Empty:
logger.warning("No iopub message received.")
break
return result

def __interrupt_kernel(self) -> None:
    """Send an interrupt request to the kernel and wait for the matching reply.

    An ``interrupt_request`` message is sent on the control channel, then
    control messages are read for up to five seconds. Only an
    ``interrupt_reply`` whose parent header carries the ``msg_id`` of the
    request sent here is treated as the answer; any other control message
    is skipped so that a stale or unrelated reply is not mistaken for it.

    This is best effort: the outcome is only logged and no exception is
    propagated to the caller.

    See <https://jupyter-client.readthedocs.io/en/latest/messaging.html#kernel-interrupt>
    """
    # Local imports keep this block self-contained; both are standard library.
    import queue
    import time

    reply_timeout = 5  # seconds to wait for the kernel to acknowledge the interrupt
    try:
        interrupt_msg = self.client.session.msg("interrupt_request", content={})
        self.client.control_channel.send(interrupt_msg)
        request_id = interrupt_msg["header"]["msg_id"]

        deadline = time.monotonic() + reply_timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                logger.warning("Kernel %s did not reply to the interrupt request in time.", self.kernel_id)
                return
            try:
                control_msg = self.client.get_control_msg(timeout=remaining)
            except queue.Empty:
                # Nothing arrived on the control channel before the deadline.
                logger.warning("Kernel %s did not reply to the interrupt request in time.", self.kernel_id)
                return

            if control_msg["msg_type"] != "interrupt_reply":
                # Not a reply to an interrupt at all; keep waiting.
                continue
            if control_msg.get("parent_header", {}).get("msg_id") != request_id:
                # A reply to some other interrupt request; keep waiting for ours.
                continue

            status = control_msg["content"]["status"]
            if status == "ok":
                logger.info("Kernel %s interrupt signal sent successfully.", self.kernel_id)
            else:
                logger.warning("Kernel %s interrupt signal sent failed: %s", self.kernel_id, status)
            return
    except Exception as e:  # noqa: BLE001
        # Interrupting is best effort: a failure here must not mask the
        # timeout handling in the caller, so it is logged and swallowed.
        logger.warning("Failed to send interrupt message to kernel %s: %s", self.kernel_id, e)


class LocalPyBoxManager(BasePyBoxManager):
def __init__(
Expand Down
46 changes: 46 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,49 @@ async def test_not_output_async(local_box: LocalPyBox):
test_code = "a = 1"
res = await local_box.arun(code=test_code)
assert res is None


def test_execute_timeout(local_box: LocalPyBox):
    """Running code that sleeps past the timeout must raise CodeExecutionError."""
    slow_snippet = "import time\ntime.sleep(6)"
    with pytest.raises(CodeExecutionError) as raised:  # noqa: PT012
        local_box.run(code=slow_snippet, timeout=5)
        # NOTE(review): assuming the original indentation put this assert inside
        # the raises block (the PT012 suppression suggests so), it sits after the
        # raising call and presumably never runs -- confirm and move it out.
        assert raised.value.args[0] == "KeyboardInterrupt"


@pytest.mark.asyncio
async def test_execute_timeout_async(local_box: LocalPyBox):
    """Async variant: code that sleeps past the timeout must raise CodeExecutionError."""
    slow_snippet = "import time\ntime.sleep(6)"
    with pytest.raises(CodeExecutionError) as raised:  # noqa: PT012
        await local_box.arun(code=slow_snippet, timeout=5)
        # NOTE(review): assuming the original indentation put this assert inside
        # the raises block (the PT012 suppression suggests so), it sits after the
        # raising call and presumably never runs -- confirm and move it out.
        assert raised.value.args[0] == "KeyboardInterrupt"


def test_interrupt_kernel(local_box: LocalPyBox):
    """Kernel state defined before a timed-out run must still be usable afterwards."""
    # Define a variable in the kernel before the run that times out.
    local_box.run(code="a = 1")

    slow_snippet = "import time\ntime.sleep(6)"
    with pytest.raises(CodeExecutionError) as raised:  # noqa: PT012
        local_box.run(code=slow_snippet, timeout=5)
        # NOTE(review): assuming the original indentation put this assert inside
        # the raises block (the PT012 suppression suggests so), it sits after the
        # raising call and presumably never runs -- confirm and move it out.
        assert raised.value.args[0] == "KeyboardInterrupt"

    # The variable defined earlier must still be printable after the timed-out run.
    output = local_box.run(code="print(a)")
    assert output.text == "1\n"


@pytest.mark.asyncio
async def test_interrupt_kernel_async(local_box: LocalPyBox):
    """Async variant: kernel state must still be usable after a timed-out run."""
    # Define a variable in the kernel before the run that times out.
    await local_box.arun(code="a = 1")

    slow_snippet = "import time\ntime.sleep(10)"
    with pytest.raises(CodeExecutionError) as raised:  # noqa: PT012
        await local_box.arun(code=slow_snippet, timeout=5)
        # NOTE(review): assuming the original indentation put this assert inside
        # the raises block (the PT012 suppression suggests so), it sits after the
        # raising call and presumably never runs -- confirm and move it out.
        assert raised.value.args[0] == "KeyboardInterrupt"

    # The variable defined earlier must still be printable after the timed-out run.
    output = await local_box.arun(code="print(a)")
    assert output.text == "1\n"

0 comments on commit e7d95bf

Please sign in to comment.