Skip to content
230 changes: 120 additions & 110 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def __init__(self, subject_id, hemi, surf, title=None,
# for now only one color bar can be added
# since it is the same for all figures
self._colorbar_added = False
# for now only one time label can be added
# since it is the same for all figures
self._time_label_added = False
# array of data used by TimeViewer
self._data = {}
self.geo, self._hemi_meshes, self._overlays = {}, {}, {}
Expand Down Expand Up @@ -395,8 +398,11 @@ def time_label(x):
self._data['time_idx'] = time_idx
self._data['transparent'] = transparent
# data specific for a hemi
self._data[hemi + '_array'] = array
self._data[hemi + '_vertices'] = vertices
self._data[hemi] = dict()
self._data[hemi]['actor'] = list()
self._data[hemi]['mesh'] = list()
self._data[hemi]['array'] = array
self._data[hemi]['vertices'] = vertices

self._data['alpha'] = alpha
self._data['colormap'] = colormap
Expand All @@ -416,7 +422,7 @@ def time_label(x):
adj_mat,
smoothing_steps)
act_data = smooth_mat.dot(act_data)
self._data[hemi + '_smooth_mat'] = smooth_mat
self._data[hemi]['smooth_mat'] = smooth_mat

dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
Expand Down Expand Up @@ -445,16 +451,18 @@ def time_label(x):
actor, mesh = mesh_data
else:
actor, mesh = mesh_data, None
self._data[hemi + '_actor'] = actor
self._data[hemi + '_mesh'] = mesh
self._data[hemi]['actor'].append(actor)
self._data[hemi]['mesh'].append(mesh)
if array.ndim >= 2 and callable(time_label):
time_actor = self._renderer.text2d(
x_window=0.95, y_window=y_txt,
size=time_label_size,
text=time_label(time[time_idx]),
justification='right'
)
self._data[hemi + '_time_actor'] = time_actor
if not self._time_label_added:
time_actor = self._renderer.text2d(
x_window=0.95, y_window=y_txt,
size=time_label_size,
text=time_label(time[time_idx]),
justification='right'
)
self._data['time_actor'] = time_actor
self._time_label_added = True
if colorbar and not self._colorbar_added:
self._renderer.scalarbar(source=actor, n_labels=8,
bgcolor=(0.5, 0.5, 0.5))
Expand Down Expand Up @@ -738,14 +746,17 @@ def close(self):
"""Close all figures and cleanup data structure."""
self._renderer.close()

def show_view(self, view=None, roll=None, distance=None):
def show_view(self, view=None, roll=None, distance=None, row=0, col=0,
hemi=None):
"""Orient camera to display view."""
views_dict = lh_views_dict if self._hemi == 'lh' else rh_views_dict
hemi = self._hemi if hemi is None else hemi
views_dict = lh_views_dict if hemi == 'lh' else rh_views_dict
if isinstance(view, str):
view = views_dict.get(view)
elif isinstance(view, dict):
view = View(azim=view['azimuth'],
elev=view['elevation'])
self._renderer.subplot(row, col)
self._renderer.set_camera(azimuth=view.azim,
elevation=view.elev)

Expand Down Expand Up @@ -817,152 +828,151 @@ def set_data_smoothing(self, n_steps):
"""
from ..backends._pyvista import _set_mesh_scalars
from scipy.interpolate import interp1d
time_idx = self._data['time_idx']
for hemi in ['lh', 'rh']:
pd = self._data.get(hemi + '_mesh')
if pd is not None:
array = self._data[hemi + '_array']
vertices = self._data[hemi + '_vertices']
if pd is not None:
time_idx = self._data['time_idx']
act_data = array
if self._data['array'].ndim == 2:
hemi_data = self._data.get(hemi)
if hemi_data is not None:
array = hemi_data['array']
vertices = hemi_data['vertices']
for mesh in hemi_data['mesh']:
if array.ndim == 2:
if isinstance(time_idx, int):
act_data = act_data[:, time_idx]
act_data = array[:, time_idx]
else:
times = np.arange(self._n_times)
act_data = interp1d(
times, act_data, 'linear', axis=1,
times, array, 'linear', axis=1,
assume_sorted=True)(time_idx)

adj_mat = mesh_edges(self.geo[hemi].faces)
smooth_mat = smoothing_matrix(vertices,
adj_mat, int(n_steps),
verbose=False)
act_data = smooth_mat.dot(act_data)
_set_mesh_scalars(pd, act_data, 'Data')
self._data[hemi + '_smooth_mat'] = smooth_mat
_set_mesh_scalars(mesh, act_data, 'Data')
self._data[hemi]['smooth_mat'] = smooth_mat

def set_time_point(self, time_idx):
"""Set the time point shown."""
from ..backends._pyvista import _set_mesh_scalars
from scipy.interpolate import interp1d
time = self._data['time']
time_label = self._data['time_label']
time_actor = self._data.get('time_actor')
for hemi in ['lh', 'rh']:
pd = self._data.get(hemi + '_mesh')
if pd is not None:
array = self._data[hemi + '_array']
time = self._data['time']
time_label = self._data['time_label']
time_actor = self._data.get(hemi + '_time_actor')
if array.ndim == 1:
continue # skip data without time axis
# interpolation
if array.ndim == 2:
act_data = array

if isinstance(time_idx, int):
act_data = act_data[:, time_idx]
else:
times = np.arange(self._n_times)
act_data = interp1d(times, act_data, 'linear', axis=1,
assume_sorted=True)(time_idx)

smooth_mat = self._data[hemi + '_smooth_mat']
if smooth_mat is not None:
act_data = smooth_mat.dot(act_data)
_set_mesh_scalars(pd, act_data, 'Data')
if callable(time_label) and time_actor is not None:
if isinstance(time_idx, int):
self._current_time = time[time_idx]
time_actor.SetInput(time_label(self._current_time))
else:
ifunc = interp1d(times, self._data['time'])
self._current_time = ifunc(time_idx)
time_actor.SetInput(time_label(self._current_time))
self._data['time_idx'] = time_idx
hemi_data = self._data.get(hemi)
if hemi_data is not None:
array = hemi_data['array']
for mesh in hemi_data['mesh']:
# interpolation
if array.ndim == 2:
if isinstance(time_idx, int):
act_data = array[:, time_idx]
else:
times = np.arange(self._n_times)
act_data = interp1d(times, array, 'linear', axis=1,
assume_sorted=True)(time_idx)

smooth_mat = hemi_data['smooth_mat']
if smooth_mat is not None:
act_data = smooth_mat.dot(act_data)
_set_mesh_scalars(mesh, act_data, 'Data')
if callable(time_label) and time_actor is not None:
if isinstance(time_idx, int):
self._current_time = time[time_idx]
time_actor.SetInput(time_label(self._current_time))
else:
ifunc = interp1d(times, self._data['time'])
self._current_time = ifunc(time_idx)
time_actor.SetInput(time_label(self._current_time))
self._data['time_idx'] = time_idx

def update_fmax(self, fmax):
"""Set the colorbar max point."""
from ..backends._pyvista import _set_colormap_range
ctable = self.update_lut(fmax=fmax)
ctable = (ctable * 255).astype(np.uint8)
center = self._data['center']
fmin = self._data['fmin']
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmin = self._data['fmin']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmax'] = fmax
self._data['ctable'] = ctable
hemi_data = self._data.get(hemi)
if hemi_data is not None:
for actor in hemi_data['actor']:
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmax'] = fmax
self._data['ctable'] = ctable

def update_fmid(self, fmid):
"""Set the colorbar mid point."""
from ..backends._pyvista import _set_colormap_range
ctable = self.update_lut(fmid=fmid)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar)
self._data['fmid'] = fmid
self._data['ctable'] = ctable
hemi_data = self._data.get(hemi)
if hemi_data is not None:
for actor in hemi_data['actor']:
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar)
self._data['fmid'] = fmid
self._data['ctable'] = ctable

def update_fmin(self, fmin):
"""Set the colorbar min point."""
from ..backends._pyvista import _set_colormap_range
ctable = self.update_lut(fmin=fmin)
ctable = (ctable * 255).astype(np.uint8)
center = self._data['center']
fmax = self._data['fmax']
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
fmax = self._data['fmax']
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmin'] = fmin
self._data['ctable'] = ctable
hemi_data = self._data.get(hemi)
if hemi_data is not None:
for actor in hemi_data['actor']:
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['fmin'] = fmin
self._data['ctable'] = ctable

def update_fscale(self, fscale):
"""Scale the colorbar points."""
from ..backends._pyvista import _set_colormap_range
center = self._data['center']
fmin = self._data['fmin'] * fscale
fmid = self._data['fmid'] * fscale
fmax = self._data['fmax'] * fscale
ctable = self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax)
ctable = (ctable * 255).astype(np.uint8)
for hemi in ['lh', 'rh']:
actor = self._data.get(hemi + '_actor')
if actor is not None:
center = self._data['center']
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['ctable'] = ctable
self._data['fmin'] = fmin
self._data['fmid'] = fmid
self._data['fmax'] = fmax
hemi_data = self._data.get(hemi)
if hemi_data is not None:
for actor in hemi_data['actor']:
dt_max = fmax
dt_min = fmin if center is None else -1 * fmax
rng = [dt_min, dt_max]
if self._colorbar_added:
scalar_bar = self._renderer.plotter.scalar_bar
else:
scalar_bar = None
_set_colormap_range(actor, ctable, scalar_bar, rng)
self._data['ctable'] = ctable
self._data['fmin'] = fmin
self._data['fmid'] = fmid
self._data['fmax'] = fmax

@property
def data(self):
Expand Down
Loading