Skip to content

Commit

Permalink
Fix: there can be only one comm_manager (ipython#1049)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
Fixes ipython#1043
  • Loading branch information
maartenbreddels authored Dec 7, 2022
1 parent 4dc3033 commit 2c80e6c
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 7 deletions.
3 changes: 3 additions & 0 deletions ipykernel/comm/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def _default_comm_id(self):
def __init__(self, *args, **kwargs):
# Comm takes positional arguments, LoggingConfigurable does not, so we explicitly forward arguments
traitlets.config.LoggingConfigurable.__init__(self, **kwargs)
for name in self.trait_names():
if name in kwargs:
kwargs.pop(name)
BaseComm.__init__(self, *args, **kwargs)


Expand Down
13 changes: 12 additions & 1 deletion ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import getpass
import signal
import sys
import threading
import typing as t
from contextlib import contextmanager
from functools import partial
Expand Down Expand Up @@ -46,9 +47,19 @@ def _create_comm(*args, **kwargs):
return BaseComm(*args, **kwargs)


# there can only be one comm manager in a ipykernel process
_comm_lock = threading.Lock()
_comm_manager: t.Optional[CommManager] = None


def _get_comm_manager(*args, **kwargs):
"""Create a new CommManager."""
return CommManager(*args, **kwargs)
global _comm_manager
if _comm_manager is None:
with _comm_lock:
if _comm_manager is None:
_comm_manager = CommManager(*args, **kwargs)
return _comm_manager


comm.create_comm = _create_comm
Expand Down
88 changes: 84 additions & 4 deletions ipykernel/tests/test_comm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,87 @@
from ipykernel.comm import Comm
from ipykernel.comm import Comm, CommManager
from ipykernel.ipkernel import IPythonKernel


async def test_comm(kernel):
c = Comm()
c.kernel = kernel # type:ignore
def test_comm(kernel):
manager = CommManager(kernel=kernel)
kernel.comm_manager = manager

c = Comm(kernel=kernel)
msgs = []

def on_close(msg):
msgs.append(msg)

def on_message(msg):
msgs.append(msg)

c.publish_msg("foo")
c.open({})
c.on_msg(on_message)
c.on_close(on_close)
c.handle_msg({})
c.handle_close({})
c.close()
assert len(msgs) == 2


def test_comm_manager(kernel):
manager = CommManager(kernel=kernel)
msgs = []

def foo(comm, msg):
msgs.append(msg)
comm.close()

def fizz(comm, msg):
raise RuntimeError('hi')

def on_close(msg):
msgs.append(msg)

def on_msg(msg):
msgs.append(msg)

manager.register_target("foo", foo)
manager.register_target("fizz", fizz)

kernel.comm_manager = manager
comm = Comm()
comm.on_msg(on_msg)
comm.on_close(on_close)
manager.register_comm(comm)

assert manager.get_comm(comm.comm_id) == comm
assert manager.get_comm('foo') is None

msg = dict(content=dict(comm_id=comm.comm_id, target_name='foo'))
manager.comm_open(None, None, msg)
assert len(msgs) == 1
msg['content']['target_name'] = 'bar'
manager.comm_open(None, None, msg)
assert len(msgs) == 1
msg = dict(content=dict(comm_id=comm.comm_id, target_name='fizz'))
manager.comm_open(None, None, msg)
assert len(msgs) == 1

manager.register_comm(comm)
assert manager.get_comm(comm.comm_id) == comm
msg = dict(content=dict(comm_id=comm.comm_id))
manager.comm_msg(None, None, msg)
assert len(msgs) == 2
msg['content']['comm_id'] = 'foo'
manager.comm_msg(None, None, msg)
assert len(msgs) == 2

manager.register_comm(comm)
assert manager.get_comm(comm.comm_id) == comm
msg = dict(content=dict(comm_id=comm.comm_id))
manager.comm_close(None, None, msg)
assert len(msgs) == 3

assert comm._closed


def test_comm_in_manager(ipkernel: IPythonKernel) -> None:
comm = Comm()
assert comm.comm_id in ipkernel.comm_manager.comms
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ requires-python = ">=3.8"
dependencies = [
"debugpy>=1.0",
"ipython>=7.23.1",
"comm>=0.1",
"traitlets>=5.1.0",
"comm>=0.1.1",
"traitlets>=5.4.0",
"jupyter_client>=6.1.12",
"tornado>=6.1",
"matplotlib-inline>=0.1",
Expand Down

0 comments on commit 2c80e6c

Please sign in to comment.