Skip to content

Commit

Permalink
MRG, MAINT: Simpler vector params (#291)
Browse files Browse the repository at this point in the history
* MAINT: Simpler vector params

* FIX: Undo auto scaling

* FIX: Dup

* FIX: URL

* FIX: Better

* FIX: More tolerant of type
  • Loading branch information
larsoner authored Jun 23, 2020
1 parent b0ca3a1 commit 9e5fe1f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
command: |
python -m pip install --user -q --upgrade pip numpy
python -m pip install --user -q --upgrade --progress-bar off scipy matplotlib vtk pyqt5 pyqt5-sip nibabel sphinx numpydoc pillow imageio imageio-ffmpeg sphinx-gallery
python -m pip install --user -q --upgrade mayavi "https://api.github.com/repos/mne-tools/mne-python/zipball/master"
python -m pip install --user -q --upgrade mayavi "https://github.com/mne-tools/mne-python/archive/master.zip"
- save_cache:
key: pip-cache
paths:
Expand Down
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ before_install:
pip install https://github.com/enthought/mayavi/zipball/master;
fi;
- mkdir -p $SUBJECTS_DIR
- pip install "https://api.github.com/repos/mne-tools/mne-python/zipball/master";
- pip install "https://github.com/mne-tools/mne-python/archive/master.zip"
- python -c "import mne; mne.datasets.fetch_fsaverage(verbose=True)"

install:
Expand Down
46 changes: 19 additions & 27 deletions surfer/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def __init__(self, subject_id, hemi, surf, title=None,
title = subject_id
self.subject_id = subject_id

if not isinstance(views, list):
if not isinstance(views, (list, tuple)):
views = [views]
n_row = len(views)

Expand Down Expand Up @@ -1095,23 +1095,20 @@ def add_data(self, array, min=None, max=None, thresh=None,
smooth_mat = None

magnitude = None
magnitude_max = None
if array.ndim == 3:
if array.shape[1] != 3:
raise ValueError('If array has 3 dimensions, array.shape[1] '
'must equal 3, got %s' % (array.shape[1],))
magnitude = np.linalg.norm(array, axis=1)
if scale_factor is None:
distance = np.sum([array[:, dim, :].ptp(axis=0).max() ** 2
for dim in range(3)])
distance = 4 * np.linalg.norm(array, axis=1).max()
if distance == 0:
scale_factor = 1
else:
scale_factor = (0.4 * distance /
(4 * array.shape[0] ** (0.33)))
if self._units == 'm':
scale_factor = scale_factor / 1000.
magnitude_max = magnitude.max()
elif array.ndim not in (1, 2):
raise ValueError('array has must have 1, 2, or 3 dimensions, '
'got (%s)' % (array.ndim,))
Expand Down Expand Up @@ -1188,7 +1185,7 @@ def time_label(x):
if brain['hemi'] == hemi:
s, ct, bar, gl = brain['brain'].add_data(
array, min, mid, max, thresh, lut, colormap, alpha,
colorbar, layer_id, smooth_mat, magnitude, magnitude_max,
colorbar, layer_id, smooth_mat, magnitude,
scale_factor, vertices, vector_alpha, **kwargs)
surfs.append(s)
bars.append(bar)
Expand Down Expand Up @@ -2115,13 +2112,11 @@ def set_data_time_index(self, time_idx, interpolation='quadratic'):
if vectors is not None:
vectors = vectors[:, :, time_idx]

vector_values = scalar_data.copy()
if data['smooth_mat'] is not None:
scalar_data = data['smooth_mat'] * scalar_data
for brain in self.brains:
if brain.hemi == hemi:
brain.set_data(data['layer_id'], scalar_data,
vectors, vector_values)
brain.set_data(data['layer_id'], scalar_data, vectors)
del brain
data["time_idx"] = time_idx

Expand Down Expand Up @@ -3225,24 +3220,25 @@ def _remove_scalar_data(self, array_id):
self._mesh_clones.pop(array_id).remove()
self._mesh_dataset.point_data.remove_array(array_id)

def _add_vector_data(self, vectors, vector_values, fmin, fmid, fmax,
scale_factor_norm, vertices, vector_alpha, lut):
def _add_vector_data(self, vectors, fmin, fmid, fmax,
scale_factor, vertices, vector_alpha, lut):
vertices = slice(None) if vertices is None else vertices
x, y, z = np.array(self._geo_mesh.data.points.data)[vertices].T
vector_alpha = min(vector_alpha, 0.9999999)
with warnings.catch_warnings(record=True): # HasTraits
quiver = mlab.quiver3d(
x, y, z, vectors[:, 0], vectors[:, 1], vectors[:, 2],
scalars=vector_values, colormap='hot', vmin=fmin,
colormap='hot', vmin=fmin, scale_mode='vector',
vmax=fmax, figure=self._f, opacity=vector_alpha)

# Enable backface culling
quiver.actor.property.backface_culling = True
quiver.mlab_source.update()

# Compute scaling for the glyphs
quiver.glyph.glyph.scale_factor = (scale_factor_norm *
vector_values.max())
# Set scaling for the glyphs
quiver.glyph.glyph.scale_factor = scale_factor
quiver.glyph.glyph.clamping = False
quiver.glyph.glyph.range = (0., 1.)

# Scale colormap used for the glyphs
l_m = quiver.parent.vector_lut_manager
Expand Down Expand Up @@ -3293,7 +3289,7 @@ def add_overlay(self, old, **kwargs):

@verbose
def add_data(self, array, fmin, fmid, fmax, thresh, lut, colormap, alpha,
colorbar, layer_id, smooth_mat, magnitude, magnitude_max,
colorbar, layer_id, smooth_mat, magnitude,
scale_factor, vertices, vector_alpha, **kwargs):
"""Add data to the brain"""
# Calculate initial data to plot
Expand All @@ -3308,24 +3304,20 @@ def add_data(self, array, fmin, fmid, fmax, thresh, lut, colormap, alpha,
array_plot = magnitude[:, 0]
else:
raise ValueError("data has to be 1D, 2D, or 3D")
vector_values = array_plot
if smooth_mat is not None:
array_plot = smooth_mat * array_plot

# Copy and byteswap to deal with Mayavi bug
array_plot = _prepare_data(array_plot)

array_id, pipe = self._add_scalar_data(array_plot)
scale_factor_norm = None
if array.ndim == 3:
scale_factor_norm = scale_factor / magnitude_max
vectors = array[:, :, 0].copy()
glyphs = self._add_vector_data(
vectors, vector_values, fmin, fmid, fmax,
scale_factor_norm, vertices, vector_alpha, lut)
vectors, fmin, fmid, fmax,
scale_factor, vertices, vector_alpha, lut)
else:
glyphs = None
del scale_factor
mesh = pipe.parent
if thresh is not None:
if array_plot.min() >= thresh:
Expand Down Expand Up @@ -3364,7 +3356,7 @@ def add_data(self, array, fmin, fmid, fmax, thresh, lut, colormap, alpha,

self.data[layer_id] = dict(
array_id=array_id, mesh=mesh, glyphs=glyphs,
scale_factor_norm=scale_factor_norm)
scale_factor=scale_factor)
return surf, orig_ctable, bar, glyphs

def add_annotation(self, annot, ids, cmap, **kwargs):
Expand Down Expand Up @@ -3475,7 +3467,7 @@ def remove_data(self, layer_id):
self._remove_scalar_data(data['array_id'])
self._remove_vector_data(data['glyphs'])

def set_data(self, layer_id, values, vectors=None, vector_values=None):
def set_data(self, layer_id, values, vectors=None):
"""Set displayed data values and vectors."""
data = self.data[layer_id]
self._mesh_dataset.point_data.get_array(
Expand All @@ -3492,12 +3484,12 @@ def set_data(self, layer_id, values, vectors=None, vector_values=None):

# Update glyphs
q.mlab_source.vectors = vectors
q.mlab_source.scalars = vector_values
q.mlab_source.update()

# Update changed parameters, and glyph scaling
q.glyph.glyph.scale_factor = (data['scale_factor_norm'] *
values.max())
q.glyph.glyph.scale_factor = data['scale_factor']
q.glyph.glyph.range = (0., 1.)
q.glyph.glyph.clamping = False
l_m.load_lut_from_list(lut / 255.)
l_m.data_range = data_range

Expand Down

0 comments on commit 9e5fe1f

Please sign in to comment.