Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 58 additions & 57 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,14 +793,18 @@ def update_lut(self, fmin=None, fmid=None, fmax=None):
center = self._data['center']
colormap = self._data['colormap']
transparent = self._data['transparent']
fmin = self._data['fmin'] if fmin is None else fmin
fmid = self._data['fmid'] if fmid is None else fmid
fmax = self._data['fmax'] if fmax is None else fmax

lims = dict(fmin=fmin, fmid=fmid, fmax=fmax)
lims = {key: self._data[key] if val is None else val
for key, val in lims.items()}
assert all(val is not None for val in lims.values())
if lims['fmin'] > lims['fmid']:
lims['fmin'] = lims['fmid']
if lims['fmax'] < lims['fmid']:
lims['fmax'] = lims['fmid']
self._data.update(lims)
self._data['ctable'] = \
calculate_lut(colormap, alpha=alpha, fmin=fmin, fmid=fmid,
fmax=fmax, center=center, transparent=transparent)

calculate_lut(colormap, alpha=alpha, center=center,
transparent=transparent, **lims)
return self._data['ctable']

def set_data_smoothing(self, n_steps):
Expand Down Expand Up @@ -863,64 +867,61 @@ def set_time_point(self, time_idx):
def update_fmax(self, fmax):
"""Set the colorbar max point."""
from ..backends._pyvista import _set_colormap_range
if fmax > self._data['fmid']:
ctable = self.update_lut(fmax=fmax)
ctable = (ctable * 255).astype(np.uint8)
center = self._data['center']
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmin = self._data['fmin']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmax'] = fmax
self._data['ctable'] = ctable
ctable = self.update_lut(fmax=fmax)
ctable = (ctable * 255).astype(np.uint8)
center = self._data['center']
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmin = self._data['fmin']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmax'] = fmax
self._data['ctable'] = ctable

def update_fmid(self, fmid):
"""Set the colorbar mid point."""
from ..backends._pyvista import _set_colormap_range
if self._data['fmin'] < fmid < self._data['fmax']:
ctable = self.update_lut(fmid=fmid)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar)
self._data['fmid'] = fmid
self._data['ctable'] = ctable
ctable = self.update_lut(fmid=fmid)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar)
self._data['fmid'] = fmid
self._data['ctable'] = ctable

def update_fmin(self, fmin):
"""Set the colorbar min point."""
from ..backends._pyvista import _set_colormap_range
if fmin < self._data['fmid']:
ctable = self.update_lut(fmin=fmin)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmax = self._data['fmax']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmin'] = fmin
self._data['ctable'] = ctable
ctable = self.update_lut(fmin=fmin)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmax = self._data['fmax']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmin'] = fmin
self._data['ctable'] = ctable

def update_fscale(self, fscale):
"""Scale the colorbar points."""
Expand Down
106 changes: 90 additions & 16 deletions mne/viz/_brain/_timeviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#
# License: Simplified BSD

import time
import numpy as np


class IntSlider(object):
"""Class to set a integer slider."""
Expand All @@ -13,9 +16,9 @@ def __init__(self, plotter=None, callback=None, name=None):
self.callback = callback
self.name = name

def __call__(self, idx):
def __call__(self, value):
"""Round the label of the slider."""
idx = int(round(idx))
idx = int(round(value))
for slider in self.plotter.slider_widgets:
name = getattr(slider, "name", None)
if name == self.name:
Expand Down Expand Up @@ -50,6 +53,58 @@ def __call__(self, value):
slider_rep.SetValue(fmax)


class BumpColorbarPoints(object):
"""Class that ensure constraints over the colorbar points."""

def __init__(self, plotter=None, brain=None, name=None):
self.plotter = plotter
self.brain = brain
self.name = name
self.callback = {
"fmin": brain.update_fmin,
"fmid": brain.update_fmid,
"fmax": brain.update_fmax
}
self.last_update = time.time()

def __call__(self, value):
"""Update the colorbar sliders."""
keys = ('fmin', 'fmid', 'fmax')
vals = {key: self.brain._data[key] for key in keys}
reps = {key: None for key in keys}
for slider in self.plotter.slider_widgets:
name = getattr(slider, "name", None)
if name is not None:
reps[name] = slider.GetRepresentation()
if self.name == "fmin" and reps["fmin"] is not None:
if vals['fmax'] < value:
self.brain.update_fmax(value)
reps['fmax'].SetValue(value)
if vals['fmid'] < value:
self.brain.update_fmid(value)
reps['fmid'].SetValue(value)
reps['fmin'].SetValue(value)
elif self.name == "fmid" and reps['fmid'] is not None:
if vals['fmin'] > value:
self.brain.update_fmin(value)
reps['fmin'].SetValue(value)
if vals['fmax'] < value:
self.brain.update_fmax(value)
reps['fmax'].SetValue(value)
reps['fmid'].SetValue(value)
elif self.name == "fmax" and reps['fmax'] is not None:
if vals['fmin'] > value:
self.brain.update_fmin(value)
reps['fmin'].SetValue(value)
if vals['fmid'] > value:
self.brain.update_fmid(value)
reps['fmid'].SetValue(value)
reps['fmax'].SetValue(value)
if time.time() > self.last_update + 1. / 60.:
self.callback[self.name](value)
self.last_update = time.time()


class _TimeViewer(object):
"""Class to interact with _Brain."""

Expand All @@ -66,20 +121,20 @@ def __init__(self, brain):

# smoothing slider
default_smoothing_value = 7
set_smoothing = IntSlider(
self.set_smoothing = IntSlider(
plotter=self.plotter,
callback=brain.set_data_smoothing,
name="smoothing"
)
smoothing_slider = self.plotter.add_slider_widget(
set_smoothing,
self.set_smoothing,
value=default_smoothing_value,
rng=[1, 15], title="smoothing",
pointa=(0.82, 0.90),
pointb=(0.98, 0.90)
)
smoothing_slider.name = 'smoothing'
set_smoothing(default_smoothing_value)
self.set_smoothing(default_smoothing_value)

# orientation slider
orientation = [
Expand Down Expand Up @@ -122,30 +177,48 @@ def __init__(self, brain):
# colormap slider
scaling_limits = [0.2, 2.0]
fmin = brain._data["fmin"]
self.update_fmin = BumpColorbarPoints(
plotter=self.plotter,
brain=brain,
name="fmin"
)
fmin_slider = self.plotter.add_slider_widget(
brain.update_fmin,
self.update_fmin,
value=fmin,
rng=_get_range(fmin, scaling_limits), title="fmin",
rng=_get_range(brain), title="fmin",
pointa=(0.82, 0.26),
pointb=(0.98, 0.26)
pointb=(0.98, 0.26),
event_type="always",
)
fmin_slider.name = "fmin"
fmid = brain._data["fmid"]
self.update_fmid = BumpColorbarPoints(
plotter=self.plotter,
brain=brain,
name="fmid",
)
fmid_slider = self.plotter.add_slider_widget(
brain.update_fmid,
self.update_fmid,
value=fmid,
rng=_get_range(fmid, scaling_limits), title="fmid",
rng=_get_range(brain), title="fmid",
pointa=(0.82, 0.42),
pointb=(0.98, 0.42)
pointb=(0.98, 0.42),
event_type="always",
)
fmid_slider.name = "fmid"
fmax = brain._data["fmax"]
self.update_fmax = BumpColorbarPoints(
plotter=self.plotter,
brain=brain,
name="fmax",
)
fmax_slider = self.plotter.add_slider_widget(
brain.update_fmax,
self.update_fmax,
value=fmax,
rng=_get_range(fmax, scaling_limits), title="fmax",
rng=_get_range(brain), title="fmax",
pointa=(0.82, 0.58),
pointb=(0.98, 0.58)
pointb=(0.98, 0.58),
event_type="always",
)
fmax_slider.name = "fmax"
update_fscale = UpdateColorbarScale(
Expand Down Expand Up @@ -194,5 +267,6 @@ def _set_slider_style(slider, show_label=True):
slider_rep.ShowSliderLabelOff()


def _get_range(val, rng):
return [val * rng[0], val * rng[1]]
def _get_range(brain):
val = np.abs(brain._data['array'])
return [np.min(val), np.max(val)]
Loading