Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ Bugs

- The RMS trace shown in the time viewer of `~mne.SourceEstimate` plots is now correctly labeled as ``RMS`` (was ``GFP`` before) (:gh:`8965` by `Richard Höchenberger`_)

- Fix bug with :func:`mne.SourceEstimate.plot` and related functions where the scalars were not interactively updated properly (:gh:`8985` by `Eric Larson`_)

- Fix bug with mne.channels.find_ch_adjacency() returning wrong adjacency for Neuromag122-Data (:gh:`8891` by `Martin Schulz`_)

- Fix :func:`mne.read_dipole` yielding :class:`mne.Dipole` objects that could not be indexed (:gh:`8963` by `Marijn van Vliet`_)
Expand Down
8 changes: 7 additions & 1 deletion mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,10 +503,16 @@ def brain_gc(request):
yield
return
from mne.viz import Brain
_assert_no_instances(Brain, 'before')
ignore = set(id(o) for o in gc.get_objects())
yield
close_func()
# no need to warn if the test itself failed, pytest-harvest helps us here
try:
outcome = request.node.harvest_rep_call
except Exception:
outcome = 'failed'
if outcome != 'passed':
return
_assert_no_instances(Brain, 'after')
# We only check VTK for PyVista -- Mayavi/PySurfer is not as strict
objs = gc.get_objects()
Expand Down
2 changes: 1 addition & 1 deletion mne/tests/test_source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ def objective(x):
stc_max, directions = stc.project('pca')
flips = np.sign(np.sum(directions * want_nn, axis=1, keepdims=True))
directions *= flips
assert_allclose(directions, want_nn, atol=1e-6)
assert_allclose(directions, want_nn, atol=2e-6)


@testing.requires_testing_data
Expand Down
132 changes: 89 additions & 43 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .view import views_dicts, _lh_views_dict
from .mplcanvas import MplCanvas
from .callback import (ShowView, TimeCallBack, SmartCallBack, Widget,
BumpColorbarPoints, UpdateColorbarScale)
UpdateLUT, UpdateColorbarScale)

from ..utils import _show_help, _get_color_list, concatenate_images
from .._3d import _process_clim, _handle_time, _check_views
Expand Down Expand Up @@ -54,39 +54,46 @@ def safe_event(fun, *args, **kwargs):


class _Overlay(object):
def __init__(self, scalars, colormap, rng, opacity):
def __init__(self, scalars, colormap, rng, opacity, name):
self._scalars = scalars
self._colormap = colormap
assert rng is not None
self._rng = rng
self._opacity = opacity
self._name = name

def to_colors(self):
from .._3d import _get_cmap
from matplotlib.colors import ListedColormap

if isinstance(self._colormap, str):
kind = self._colormap
cmap = _get_cmap(self._colormap)
else:
cmap = ListedColormap(self._colormap / 255.)

def diff(x):
return np.max(x) - np.min(x)

def norm(x, rng=None):
if rng is None:
rng = [np.min(x), np.max(x)]
return (x - rng[0]) / (rng[1] - rng[0])
kind = str(type(self._colormap))
logger.debug(
f'Color mapping {repr(self._name)} with {kind} '
f'colormap and range {self._rng}')

rng = self._rng
scalars = self._scalars
if diff(scalars) != 0:
scalars = norm(scalars, rng)
assert rng is not None
scalars = _norm(self._scalars, rng)

colors = cmap(scalars)
if self._opacity is not None:
colors[:, 3] *= self._opacity
return colors


def _norm(x, rng):
if rng[0] == rng[1]:
factor = 1 if rng[0] == 0 else 1e-6 * rng[0]
else:
factor = rng[1] - rng[0]
return (x - rng[0]) / factor


class _LayeredMesh(object):
def __init__(self, renderer, vertices, triangles, normals):
self._renderer = renderer
Expand Down Expand Up @@ -149,7 +156,8 @@ def add_overlay(self, scalars, colormap, rng, opacity, name):
scalars=scalars,
colormap=colormap,
rng=rng,
opacity=opacity
opacity=opacity,
name=name,
)
self._overlays[name] = overlay
colors = overlay.to_colors()
Expand Down Expand Up @@ -193,7 +201,7 @@ def _clean(self):
self._renderer = None

def update_overlay(self, name, scalars=None, colormap=None,
opacity=None):
opacity=None, rng=None):
overlay = self._overlays.get(name, None)
if overlay is None:
return
Expand All @@ -203,6 +211,8 @@ def update_overlay(self, name, scalars=None, colormap=None,
overlay._colormap = colormap
if opacity is not None:
overlay._opacity = opacity
if rng is not None:
overlay._rng = rng
self.update()


Expand Down Expand Up @@ -423,6 +433,7 @@ def __init__(self, subject_id, hemi, surf, title=None,
self._annots = {'lh': list(), 'rh': list()}
self._layered_meshes = {}
self._elevation_rng = [15, 165] # range of motion of camera on theta
self._lut_locked = None
# default values for silhouette
self._silhouette = {
'color': self._bg_color,
Expand Down Expand Up @@ -777,17 +788,10 @@ def toggle_interface(self, value=None):
def apply_auto_scaling(self):
"""Detect automatically fitting scaling parameters."""
self._update_auto_scaling()
for key in self.keys:
self.widgets[key].set_value(self._data[key])
self._update()

def restore_user_scaling(self):
"""Restore original scaling parameters."""
self._update_auto_scaling(restore=True)
for key in self.keys:
self.widgets[key].set_value(self._data[key])
self.widgets[f"entry_{key}"].set_value(self._data[key])
self._update()

def toggle_playback(self, value=None):
"""Toggle time playback.
Expand Down Expand Up @@ -1043,13 +1047,11 @@ def _configure_dock_colormap_widget(self, name):
align=True,
layout=layout,
)
for idx, key in enumerate(self.keys):
up = UpdateLUT(brain=self)
for key in self.keys:
hlayout = self._renderer._dock_add_layout(vertical=False)
rng = _get_range(self)
self.callbacks[key] = BumpColorbarPoints(
brain=self,
name=key
)
self.callbacks[key] = lambda value, key=key: up(**{key: value})
self.widgets[key] = Widget(
widget=self._renderer._dock_add_slider(
name=None,
Expand All @@ -1071,6 +1073,7 @@ def _configure_dock_colormap_widget(self, name):
),
notebook=self.notebook,
)
up.widgets[key] = [self.widgets[key], self.widgets[f"entry_{key}"]]
if self.notebook:
from ..backends._notebook import _ipy_add_widget
_ipy_add_widget(layout, hlayout, self._renderer.dock_width)
Expand Down Expand Up @@ -2106,12 +2109,7 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None,
self._renderer.set_camera(**views_dicts[hemi][v])

# 4) update the scalar bar and opacity
self.update_lut()
if hemi in self._layered_meshes:
mesh = self._layered_meshes[hemi]
mesh.update_overlay(name='data', opacity=alpha)

self._update()
self.update_lut(alpha=alpha)

def _iter_views(self, hemi):
# which rows and columns each type of visual needs to be added to
Expand Down Expand Up @@ -2400,7 +2398,7 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None,
mesh.add_overlay(
scalars=scalars,
colormap=ctable,
rng=None,
rng=[np.min(scalars), np.max(scalars)],
opacity=alpha,
name=label_name,
)
Expand Down Expand Up @@ -2766,25 +2764,37 @@ def screenshot(self, mode='rgb', time_viewer=False):
[img, trace_img], bgcolor=self._brain_color[:3])
return img

@contextlib.contextmanager
def _no_lut_update(self, why):
orig = self._lut_locked
self._lut_locked = why
try:
yield
finally:
self._lut_locked = orig

@fill_doc
def update_lut(self, fmin=None, fmid=None, fmax=None):
def update_lut(self, fmin=None, fmid=None, fmax=None, alpha=None):
"""Update color map.

Parameters
----------
%(fmin_fmid_fmax)s
alpha : float | None
Alpha to use in the update.
"""
args = f'{fmin}, {fmid}, {fmax}, {alpha}'
if self._lut_locked is not None:
logger.debug(f'LUT update postponed with {args}')
return
logger.debug(f'Updating LUT with {args}')
center = self._data['center']
colormap = self._data['colormap']
transparent = self._data['transparent']
lims = dict(fmin=fmin, fmid=fmid, fmax=fmax)
lims = {key: self._data[key] if val is None else val
for key, val in lims.items()}
lims = {key: self._data[key] for key in ('fmin', 'fmid', 'fmax')}
_update_monotonic(lims, fmin=fmin, fmid=fmid, fmax=fmax)
assert all(val is not None for val in lims.values())
if lims['fmin'] > lims['fmid']:
lims['fmin'] = lims['fmid']
if lims['fmax'] < lims['fmid']:
lims['fmax'] = lims['fmid']

self._data.update(lims)
self._data['ctable'] = np.round(
calculate_lut(colormap, alpha=1., center=center,
Expand All @@ -2802,7 +2812,9 @@ def update_lut(self, fmin=None, fmid=None, fmax=None):
if hemi in self._layered_meshes:
mesh = self._layered_meshes[hemi]
mesh.update_overlay(name='data',
colormap=self._data['ctable'])
colormap=self._data['ctable'],
opacity=alpha,
rng=rng)
self._renderer._set_colormap_range(
mesh._actor, ctable, scalar_bar, rng,
self._brain_color)
Expand All @@ -2823,6 +2835,10 @@ def update_lut(self, fmin=None, fmid=None, fmax=None):
self._renderer._set_colormap_range(
glyph_actor_, ctable, scalar_bar, rng)
scalar_bar = None
if self.time_viewer:
with self._no_lut_update(f'update_lut {args}'):
for key in ('fmin', 'fmid', 'fmax'):
self.callbacks[key](lims[key])
self._update()

def set_data_smoothing(self, n_steps):
Expand Down Expand Up @@ -3468,6 +3484,36 @@ def _update_limits(fmin, fmid, fmax, center, array):
return fmin, fmid, fmax


def _update_monotonic(lims, fmin, fmid, fmax):
if fmin is not None:
lims['fmin'] = fmin
if lims['fmax'] < fmin:
logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmin}')
lims['fmax'] = fmin
if lims['fmid'] < fmin:
logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmin}')
lims['fmid'] = fmin
assert lims['fmin'] <= lims['fmid'] <= lims['fmax']
if fmid is not None:
lims['fmid'] = fmid
if lims['fmin'] > fmid:
logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmid}')
lims['fmin'] = fmid
if lims['fmax'] < fmid:
logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmid}')
lims['fmax'] = fmid
assert lims['fmin'] <= lims['fmid'] <= lims['fmax']
if fmax is not None:
lims['fmax'] = fmax
if lims['fmin'] > fmax:
logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmax}')
lims['fmin'] = fmax
if lims['fmid'] > fmax:
logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmax}')
lims['fmid'] = fmax
assert lims['fmin'] <= lims['fmid'] <= lims['fmax']


def _get_range(brain):
val = np.abs(np.concatenate(list(brain._current_act_data.values())))
return [np.min(val), np.max(val)]
Expand Down
55 changes: 13 additions & 42 deletions mne/viz/_brain/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Guillaume Favelier <guillaume.favelier@gmail.com>
#
# License: Simplified BSD
import time
from ...utils import logger


class Widget(object):
Expand Down Expand Up @@ -82,51 +82,22 @@ def __call__(self):
self.widgets[key].set_value(self.brain._data[key])


class BumpColorbarPoints(object):
"""Class that ensure constraints over the colorbar points."""
class UpdateLUT(object):
"""Update the LUT."""

def __init__(self, brain=None, name=None):
def __init__(self, brain=None):
self.brain = brain
self.name = name
self.callback = {
"fmin": lambda fmin: brain.update_lut(fmin=fmin),
"fmid": lambda fmid: brain.update_lut(fmid=fmid),
"fmax": lambda fmax: brain.update_lut(fmax=fmax),
}
self.widgets = {key: None for key in self.brain.keys}
self.last_update = time.time()
self.widgets = {key: list() for key in self.brain.keys}

def __call__(self, value):
def __call__(self, fmin=None, fmid=None, fmax=None):
"""Update the colorbar sliders."""
vals = {key: self.brain._data[key] for key in self.brain.keys}
if self.name == "fmin" and self.widgets["fmin"] is not None:
if vals['fmax'] < value:
vals['fmax'] = value
self.widgets['fmax'].set_value(value)
if vals['fmid'] < value:
vals['fmid'] = value
self.widgets['fmid'].set_value(value)
self.widgets['fmin'].set_value(value)
elif self.name == "fmid" and self.widgets['fmid'] is not None:
if vals['fmin'] > value:
vals['fmin'] = value
self.widgets['fmin'].set_value(value)
if vals['fmax'] < value:
vals['fmax'] = value
self.widgets['fmax'].set_value(value)
self.widgets['fmid'].set_value(value)
elif self.name == "fmax" and self.widgets['fmax'] is not None:
if vals['fmin'] > value:
vals['fmin'] = value
self.widgets['fmin'].set_value(value)
if vals['fmid'] > value:
vals['fmid'] = value
self.widgets['fmid'].set_value(value)
self.widgets['fmax'].set_value(value)
self.brain.widgets[f'entry_{self.name}'].set_value(value)
if time.time() > self.last_update + 1. / 60.:
self.callback[self.name](value)
self.last_update = time.time()
self.brain.update_lut(fmin=fmin, fmid=fmid, fmax=fmax)
with self.brain._no_lut_update(f'UpdateLUT {fmin} {fmid} {fmax}'):
for key in ('fmin', 'fmid', 'fmax'):
value = self.brain._data[key]
logger.debug(f'Updating {key} = {value}')
for widget in self.widgets[key]:
widget.set_value(value)


class ShowView(object):
Expand Down
Loading