Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,8 @@ def _plot_mpl_stc(stc, subject=None, surface='inflated', hemi='lh',
return fig


def link_brains(brains, time=True, camera=False, colorbar=True):
def link_brains(brains, time=True, camera=False, colorbar=True,
picking=False):
"""Plot multiple SourceEstimate objects with PyVista.

Parameters
Expand All @@ -1534,6 +1535,8 @@ def link_brains(brains, time=True, camera=False, colorbar=True):
If True, link the camera controls. Defaults to False.
colorbar : bool
If True, link the colorbar controllers. Defaults to True.
picking : bool
If True, link the vertices picked with the mouse. Defaults to False.
"""
from .backends.renderer import _get_3d_backend
if _get_3d_backend() != 'pyvista':
Expand All @@ -1553,7 +1556,13 @@ def link_brains(brains, time=True, camera=False, colorbar=True):
raise TypeError("Expected type is Brain but"
" {} was given.".format(type(brain)))
# link brains properties
_LinkViewer(brains, time, camera, colorbar)
_LinkViewer(
brains=brains,
time=time,
camera=camera,
colorbar=colorbar,
picking=picking,
)


def _triage_stc(stc, src, surface, backend_name, kind='scalar'):
Expand Down
75 changes: 60 additions & 15 deletions mne/viz/_brain/_timeviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def __init__(self, brain, show_traces=False):
self.act_data_smooth = {key: (None, None) for key in all_keys}
self.color_cycle = None
self.picked_points = {key: list() for key in all_keys}
self.pick_table = dict()
self._mouse_no_mvt = -1
self.icons = dict()
self.actions = dict()
Expand Down Expand Up @@ -1133,6 +1134,9 @@ def on_pick(self, vtk_picker, event):
self.add_point(hemi, mesh, vertex_id)

def add_point(self, hemi, mesh, vertex_id):
# skip if the wrong hemi is selected
if self.act_data_smooth[hemi][0] is None:
return
from ..backends._pyvista import _sphere
color = next(self.color_cycle)
line = self.plot_time_course(hemi, vertex_id, color)
Expand Down Expand Up @@ -1184,34 +1188,40 @@ def add_point(self, hemi, mesh, vertex_id):
sphere._actors = actors
sphere._color = color
sphere._vertex_id = vertex_id
sphere._spheres = spheres

self.picked_points[hemi].append(vertex_id)
self._spheres.extend(spheres)
self.pick_table[vertex_id] = spheres

def remove_point(self, mesh):
if mesh._spheres is None:
return # already removed
mesh._line.remove()
vertex_id = mesh._vertex_id
if vertex_id not in self.pick_table:
return

hemi = mesh._hemi
color = mesh._color
spheres = self.pick_table[vertex_id]
spheres[0]._line.remove()
self.mpl_canvas.update_plot()
self.picked_points[mesh._hemi].remove(mesh._vertex_id)
self.picked_points[hemi].remove(vertex_id)

with warnings.catch_warnings(record=True):
# We intentionally ignore these in case we have traversed the
# entire color cycle
warnings.simplefilter('ignore')
self.color_cycle.restore(mesh._color)
# remove all actors
self.plotter.remove_actor(mesh._actors)
mesh._actors = None
# remove all meshes from sphere list
for sphere in list(mesh._spheres): # includes itself, so copy
self.color_cycle.restore(color)
for sphere in spheres:
# remove all actors
self.plotter.remove_actor(sphere._actors)
sphere._actors = None
self._spheres.pop(self._spheres.index(sphere))
sphere._spheres = sphere._actors = None
self.pick_table.pop(vertex_id)

def clear_points(self):
for sphere in list(self._spheres): # will remove itself, so copy
self.remove_point(sphere)
assert sum(len(v) for v in self.picked_points.values()) == 0
assert len(self.pick_table) == 0
assert len(self._spheres) == 0

def plot_time_course(self, hemi, vertex_id, color):
Expand Down Expand Up @@ -1348,7 +1358,8 @@ def clean(self):
class _LinkViewer(object):
"""Class to link multiple _TimeViewer objects."""

def __init__(self, brains, time=True, camera=False, colorbar=True):
def __init__(self, brains, time=True, camera=False, colorbar=True,
picking=False):
self.brains = brains
self.time_viewers = [brain.time_viewer for brain in brains]

Expand Down Expand Up @@ -1383,13 +1394,47 @@ def __init__(self, brains, time=True, camera=False, colorbar=True):
self.toggle_playback)

# link time course canvas
def _func(*args, **kwargs):
def _time_func(*args, **kwargs):
for time_viewer in self.time_viewers:
time_viewer.time_call(*args, **kwargs)

for time_viewer in self.time_viewers:
if time_viewer.show_traces:
time_viewer.mpl_canvas.time_func = _func
time_viewer.mpl_canvas.time_func = _time_func

if picking:
def _func_add(*args, **kwargs):
for time_viewer in self.time_viewers:
time_viewer._add_point(*args, **kwargs)
time_viewer.plotter.update()

def _func_remove(*args, **kwargs):
for time_viewer in self.time_viewers:
time_viewer._remove_point(*args, **kwargs)

# save initial picked points
initial_points = dict()
for hemi in ('lh', 'rh'):
initial_points[hemi] = set()
for time_viewer in self.time_viewers:
initial_points[hemi] |= \
set(time_viewer.picked_points[hemi])

# link the viewers
for time_viewer in self.time_viewers:
time_viewer.clear_points()
time_viewer._add_point = time_viewer.add_point
time_viewer.add_point = _func_add
time_viewer._remove_point = time_viewer.remove_point
time_viewer.remove_point = _func_remove

# link the initial points
leader = self.time_viewers[0] # select a time_viewer as leader
for hemi in initial_points.keys():
if hemi in time_viewer.brain._hemi_meshes:
mesh = time_viewer.brain._hemi_meshes[hemi]
for vertex_id in initial_points[hemi]:
leader.add_point(hemi, mesh, vertex_id)

if colorbar:
for slider_name in ('min', 'mid', 'max'):
Expand Down
1 change: 1 addition & 0 deletions mne/viz/_brain/tests/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def test_brain_linkviewer(renderer_interactive, travis_macos):
[brain_data],
time=True,
camera=True,
picking=True,
)
link_viewer.set_time_point(value=0)
link_viewer.set_playback_speed(value=0.1)
Expand Down