Skip to content

Commit

Permalink
feat(local): interrupt the kernel on execution timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
afsu committed Aug 22, 2024
1 parent 1313e56 commit aaef478
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 30 deletions.
79 changes: 49 additions & 30 deletions src/pybox/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,14 @@ def __wait_for_execute_reply(self, msg_id: str, **kwargs) -> ExecutionResponse |
while True:
try:
shell_msg = self.client.get_shell_msg(**kwargs)
if (shell_msg["parent_header"]["msg_id"] != msg_id) or (shell_msg["msg_type"] != "execute_reply"):
continue
# See <https://jupyter-client.readthedocs.io/en/latest/messaging.html#execution-results>
# error execution may have extra messages, for example a stream std error
response = ExecutionResponse.model_validate(shell_msg)
if response.content.status == "error":
raise CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)
if (shell_msg["parent_header"]["msg_id"] != msg_id) or (shell_msg["msg_type"] != "execute_reply"):
continue
except queue.Empty:
logger.warning("No shell message received.")
self.__interrupt_kernel()
return None
else:
return response
return ExecutionResponse.model_validate(shell_msg)

def __get_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
"""Retrieves output from a kernel.
Expand All @@ -70,6 +62,7 @@ def __get_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
CodeExecutionException: if the code execution fails
"""
result = None
error = None
while True:
# Poll the message
try:
Expand All @@ -94,15 +87,23 @@ def __get_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
# See <https://jupyter-client.readthedocs.io/en/stable/messaging.html#streams-stdout-stderr-etc>
if not result:
result = PyBoxOut(data={"text/plain": response.content.text})
elif response.msg_type == "error":
error = CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)

elif response.msg_type == "status": # noqa: SIM102
# According to the document <https://jupyter-client.readthedocs.io/en/latest/messaging.html#request-reply>
# The idle message will be published after processing the request and publishing associated IOPub messages
if response.content.execution_state == "idle":
break
if error is not None:
raise error
return result
except queue.Empty:
logger.warning("No iopub message received.")
break
return result
return result

async def arun(self, code: str, timeout: int = 60) -> PyBoxOut | None:
if not self.client.channels_running:
Expand All @@ -118,22 +119,14 @@ async def __await_for_execute_reply(self, msg_id: str, **kwargs) -> ExecutionRes
shell_msg = await self.client._async_get_shell_msg( # noqa: SLF001
**kwargs
)
if (shell_msg["parent_header"]["msg_id"] != msg_id) or (shell_msg["msg_type"] != "execute_reply"):
continue
# See <https://jupyter-client.readthedocs.io/en/latest/messaging.html#execution-results>
# error execution may have extra messages, for example a stream std error
response = ExecutionResponse.model_validate(shell_msg)
if response.content.status == "error":
raise CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)
if (shell_msg["parent_header"]["msg_id"] != msg_id) or (shell_msg["msg_type"] != "execute_reply"):
continue
except queue.Empty:
logger.warning("No shell message received.")
self.__interrupt_kernel()
return None
else:
return response
return ExecutionResponse.model_validate(shell_msg)

async def __aget_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
"""Retrieves output from a kernel asynchronously.
Expand All @@ -148,6 +141,7 @@ async def __aget_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
CodeExecutionException: if the code execution fails
"""
result = None
error = None
while True:
# Poll the message
try:
Expand All @@ -174,15 +168,40 @@ async def __aget_kernel_output(self, msg_id: str, **kwargs) -> PyBoxOut | None:
# See <https://jupyter-client.readthedocs.io/en/stable/messaging.html#streams-stdout-stderr-etc>
if not result:
result = PyBoxOut(data={"text/plain": response.content.text})
elif response.msg_type == "error":
error = CodeExecutionError(
ename=response.content.ename,
evalue=response.content.evalue,
traceback=response.content.traceback,
)
elif response.msg_type == "status": # noqa: SIM102
# According to the document <https://jupyter-client.readthedocs.io/en/latest/messaging.html#request-reply>
# The idle message will be published after processing the request and publishing associated IOPub messages
if response.content.execution_state == "idle":
break
if error is not None:
raise error
return result
except queue.Empty:
logger.warning("No iopub message received.")
break
return result
return result

def __interrupt_kernel(self) -> None:
"""send an interrupt message to the kernel."""
try:
interrupt_msg = self.client.session.msg("interrupt_request", content={})
self.client.control_channel.send(interrupt_msg)
control_msg = self.client.get_control_msg(timeout=5)
# TODO: Do you need to determine whether the parent id is equal to the interrupt message id?
# See <https://jupyter-client.readthedocs.io/en/latest/messaging.html#kernel-interrupt>
if control_msg["msg_type"] == "interrupt_reply":
status = control_msg["content"]["status"]
if status == "ok":
logger.info("Kernel %s interrupt signal sent successfully.", self.kernel_id)
else:
logger.warning("Kernel %s interrupt signal sent failed: %s", self.kernel_id, status)
except Exception as e: # noqa: BLE001
# TODO: What should I do if sending an interrupt message times out or fails?
logger.warning("Failed to send interrupt message to kernel %s: %s", self.kernel_id, e)


class LocalPyBoxManager(BasePyBoxManager):
Expand Down
46 changes: 46 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,49 @@ async def test_not_output_async(local_box: LocalPyBox):
test_code = "a = 1"
res = await local_box.arun(code=test_code)
assert res is None


def test_execute_timeout(local_box: LocalPyBox):
timeout_code = """import time
time.sleep(6)"""
with pytest.raises(CodeExecutionError) as exc_info: # noqa: PT012
local_box.run(code=timeout_code, timeout=5)
assert exc_info.value.args[0] == "KeyboardInterrupt"


@pytest.mark.asyncio
async def test_execute_timeout_async(local_box: LocalPyBox):
timeout_code = """import time
time.sleep(6)"""
with pytest.raises(CodeExecutionError) as exc_info: # noqa: PT012
await local_box.arun(code=timeout_code, timeout=5)
assert exc_info.value.args[0] == "KeyboardInterrupt"


def test_interrupt_kernel(local_box: LocalPyBox):
code = "a = 1"
local_box.run(code=code)

timeout_code = """import time
time.sleep(6)"""
with pytest.raises(CodeExecutionError) as exc_info: # noqa: PT012
local_box.run(code=timeout_code, timeout=5)
assert exc_info.value.args[0] == "KeyboardInterrupt"

res = local_box.run(code="print(a)")
assert res.text == "1\n"


@pytest.mark.asyncio
async def test_interrupt_kernel_async(local_box: LocalPyBox):
code = "a = 1"
await local_box.arun(code=code)

timeout_code = """import time
time.sleep(10)"""
with pytest.raises(CodeExecutionError) as exc_info: # noqa: PT012
await local_box.arun(code=timeout_code, timeout=5)
assert exc_info.value.args[0] == "KeyboardInterrupt"

res = await local_box.arun(code="print(a)")
assert res.text == "1\n"

0 comments on commit aaef478

Please sign in to comment.