Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix bug with regress_artifact picking #12389

Merged
merged 6 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12389.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :func:`mne.preprocessing.regress_artifact` projection check was not specific to the channels being processed, by `Eric Larson`_.
12 changes: 11 additions & 1 deletion mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,8 @@ def pick_info(info, sel=(), copy=True, verbose=None):
return info
elif len(sel) == 0:
raise ValueError("No channels match the selection.")
n_unique = len(np.unique(np.arange(len(info["ch_names"]))[sel]))
ch_set = set(info["ch_names"][k] for k in sel)
n_unique = len(ch_set)
if n_unique != len(sel):
raise ValueError(
"Found %d / %d unique names, sel is not unique" % (n_unique, len(sel))
Expand Down Expand Up @@ -687,6 +688,15 @@ def pick_info(info, sel=(), copy=True, verbose=None):
if info.get("custom_ref_applied", False) and not _electrode_types(info):
with info._unlock():
info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF
# remove unused projectors
if info.get("projs", []):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I stumbled over this line, because seeing the [] in foo.get("bar", []) made me expect that the result was being assigned or iterated over. Here we're only (effectively) passing it to bool() for the duration of this one line; so I think it's clearer to default to False rather than a (falsey) empty list.

Suggested change
if info.get("projs", []):
if info.get("projs", False):

projs = list()
for p in info["projs"]:
if any(ch_name in ch_set for ch_name in p["data"]["col_names"]):
projs.append(p)
if len(projs) != len(info["projs"]):
with info._unlock():
info["projs"] = projs
info._check_consistency()

return info
Expand Down
6 changes: 6 additions & 0 deletions mne/_fiff/tests/test_pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,17 @@ def test_clean_info_bads():
# simulate the bad channels
raw.info["bads"] = eeg_bad_ch + meg_bad_ch

assert len(raw.info["projs"]) == 3
raw.set_eeg_reference(projection=True)
assert len(raw.info["projs"]) == 4

# simulate the call to pick_info excluding the bad eeg channels
info_eeg = pick_info(raw.info, picks_eeg)
assert len(info_eeg["projs"]) == 1

# simulate the call to pick_info excluding the bad meg channels
info_meg = pick_info(raw.info, picks_meg)
assert len(info_meg["projs"]) == 3

assert info_eeg["bads"] == eeg_bad_ch
assert info_meg["bads"] == meg_bad_ch
Expand Down
14 changes: 9 additions & 5 deletions mne/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@
# To update the `testing` or `misc` datasets, push or merge commits to their
# respective repos, and make a new release of the dataset on GitHub. Then
# update the checksum in the MNE_DATASETS dict below, and change version
# here: ↓↓↓↓↓ ↓↓↓
RELEASES = dict(testing="0.151", misc="0.27")
# here: ↓↓↓↓↓↓↓↓
RELEASES = dict(
testing="0.151",
misc="0.27",
phantom_kit="0.2",
)
TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}'
MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}'

Expand Down Expand Up @@ -176,9 +180,9 @@
)

MNE_DATASETS["phantom_kit"] = dict(
archive_name="MNE-phantom-KIT-24bit.zip",
hash="md5:CAF82EE978DD473C7DE6C1034D9CCD45",
url="https://osf.io/download/svnt3/",
archive_name="MNE-phantom-KIT-data.tar.gz",
hash="md5:7bfdf40bbeaf17a66c99c695640e0740",
url="https://osf.io/fb6ya/download?version=1",
folder_name="MNE-phantom-KIT-data",
config_key="MNE_DATASETS_PHANTOM_KIT_PATH",
)
Expand Down
2 changes: 1 addition & 1 deletion mne/datasets/phantom_kit/phantom_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def data_path(
): # noqa: D103
return _download_mne_dataset(
name="phantom_kit",
processor="unzip",
processor="untar",
path=path,
force_update=force_update,
update_path=update_path,
Expand Down
24 changes: 13 additions & 11 deletions mne/preprocessing/_regress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .._fiff.pick import _picks_to_idx
from .._fiff.pick import _picks_to_idx, pick_info
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..epochs import BaseEpochs
from ..evoked import Evoked
Expand Down Expand Up @@ -178,9 +178,7 @@ def fit(self, inst):
reference (see :func:`mne.set_eeg_reference`) before performing EOG
regression.
"""
self._check_inst(inst)
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
picks, picks_artifact = self._check_inst(inst)

# Calculate regression coefficients. Add a row of ones to also fit the
# intercept.
Expand Down Expand Up @@ -232,9 +230,7 @@ def apply(self, inst, copy=True):
"""
if copy:
inst = inst.copy()
self._check_inst(inst)
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
picks, picks_artifact = self._check_inst(inst)

# Check that the channels are compatible with the regression weights.
ref_picks = _picks_to_idx(
Expand Down Expand Up @@ -324,19 +320,25 @@ def _check_inst(self, inst):
_validate_type(
inst, (BaseRaw, BaseEpochs, Evoked), "inst", "Raw, Epochs, Evoked"
)
if _needs_eeg_average_ref_proj(inst.info):
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
all_picks = np.unique(np.concatenate([picks, picks_artifact]))
use_info = pick_info(inst.info, all_picks)
del all_picks
if _needs_eeg_average_ref_proj(use_info):
raise RuntimeError(
"No reference for the EEG channels has been "
"set. Use inst.set_eeg_reference() to do so."
"No average reference for the EEG channels has been "
"set. Use inst.set_eeg_reference(projection=True) to do so."
)
if self.proj and not inst.proj:
inst.apply_proj()
if not inst.proj and len(inst.info.get("projs", [])) > 0:
if not inst.proj and len(use_info.get("projs", [])) > 0:
raise RuntimeError(
"Projections need to be applied before "
"regression can be performed. Use the "
".apply_proj() method to do so."
)
return picks, picks_artifact

def __repr__(self):
"""Produce a string representation of this object."""
Expand Down
13 changes: 13 additions & 0 deletions mne/preprocessing/tests/test_regress.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ def test_regress_artifact():
epochs, betas = regress_artifact(epochs, picks="eog", picks_artifact="eog")
assert np.ptp(epochs.get_data("eog")) < 1e-15 # constant value
assert_allclose(betas, 1)
# proj should only be required of channels being processed
raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
raw.del_proj()
raw.set_eeg_reference(projection=True)
model = EOGRegression(proj=False, picks="meg", picks_artifact="eog")
model.fit(raw)
model.apply(raw)
model = EOGRegression(proj=False, picks="eeg", picks_artifact="eog")
with pytest.raises(RuntimeError, match="Projections need to be applied"):
model.fit(raw)
raw.del_proj()
with pytest.raises(RuntimeError, match="No average reference for the EEG"):
model.fit(raw)


@testing.requires_testing_data
Expand Down
122 changes: 30 additions & 92 deletions tutorials/inverse/95_phantom_KIT.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,42 @@
# Copyright the MNE-Python contributors.

# %%
import matplotlib.pyplot as plt
import mne_bids
import numpy as np
from scipy.signal import find_peaks

import mne

data_path = mne.datasets.phantom_kit.data_path()
actual_pos, actual_ori = mne.dipole.get_phantom_dipoles("oyama")
actual_pos, actual_ori = actual_pos[:49], actual_ori[:49] # only 49 of 50 dipoles

raw = mne.io.read_raw_kit(data_path / "002_phantom_11Hz_100uA.con")
# cut from ~800 to ~300s for speed, and also at convenient dip stim boundaries
# chosen by examining MISC 017 by eye.
raw.crop(11.5, 302.9).load_data()
raw.filter(None, 40) # 11 Hz stimulation, no need to keep higher freqs
bids_path = mne_bids.BIDSPath(
root=data_path,
subject="01",
task="phantom",
run="01",
datatype="meg",
)
# ignore warning about misc units
raw = mne_bids.read_raw_bids(bids_path).load_data()

# Let's apply a little bit of preprocessing (temporal filtering and reference
# regression)
picks_artifact = ["MISC 001", "MISC 002", "MISC 003"]
picks = np.r_[
mne.pick_types(raw.info, meg=True),
mne.pick_channels(raw.info["ch_names"], picks_artifact),
]
raw.filter(None, 40, picks=picks)
mne.preprocessing.regress_artifact(
raw, picks="meg", picks_artifact=picks_artifact, copy=False, proj=False
)
plot_scalings = dict(mag=5e-12) # large-amplitude sinusoids
raw_plot_kwargs = dict(duration=15, n_channels=50, scalings=plot_scalings)
raw.plot(**raw_plot_kwargs)
events, event_id = mne.events_from_annotations(raw)
raw.plot(events=events, **raw_plot_kwargs)
n_dip = len(event_id)
assert n_dip == 49 # sanity check

# %%
# We can also look at the power spectral density to see the phantom oscillations at
Expand All @@ -45,82 +63,12 @@
dip_freq = 11.0
fig.axes[0].axvline(dip_freq, color="r", ls="--", lw=2, zorder=4)

# %%
# To find the events, we can look at the MISC channel that recorded the activations.
# Here we use a very simple thresholding approach to find the events.
# The MISC 017 channel holds the dipole activations, which are 2-cycle 11 Hz sinusoidal
# bursts with the initial sinusoidal deflection downward, so we do a little bit of
# signal manipulation to help :func:`~scipy.signal.find_peaks`.

# Figure out events
dip_act, dip_t = raw["MISC 017"]
dip_act = dip_act[0] # 2D to 1D array
dip_act -= dip_act.mean() # remove DC offset
dip_act *= -1 # invert so first deflection is positive
thresh = np.percentile(dip_act, 90)
min_dist = raw.info["sfreq"] / dip_freq * 0.9 # 90% of period, to be safe
peaks = find_peaks(dip_act, height=thresh, distance=min_dist)[0]
assert len(peaks) % 2 == 0 # 2-cycle modulations
peaks = peaks[::2] # take only first peaks of each 2-cycle burst

fig, ax = plt.subplots(layout="constrained", figsize=(12, 4))
stop = int(15 * raw.info["sfreq"]) # 15 sec
ax.plot(dip_t[:stop], dip_act[:stop], color="k", lw=1)
ax.axhline(thresh, color="r", ls="--", lw=1)
peak_idx = peaks[peaks < stop]
ax.plot(dip_t[peak_idx], dip_act[peak_idx], "ro", zorder=5, ms=5)
ax.set(xlabel="Time (s)", ylabel="Dipole activation (AU)\n(MISC 017 adjusted)")
ax.set(xlim=dip_t[[0, stop - 1]])

# We know that there are 32 dipoles, so mark the first ones as well
n_dip = 49
assert len(peaks) % n_dip == 0 # we found them all (hopefully)
ax.plot(dip_t[peak_idx[::n_dip]], dip_act[peak_idx[::n_dip]], "bo", zorder=4, ms=10)

# Knowing we've caught the top of the first cycle of a 11 Hz sinusoid, plot onsets
# with red X's.
onsets = peaks - np.round(raw.info["sfreq"] / dip_freq / 4.0).astype(
int
) # shift to start
onset_idx = onsets[onsets < stop]
ax.plot(dip_t[onset_idx], dip_act[onset_idx], "rx", zorder=5, ms=5)

# %%
# Given the onsets are now stored in ``peaks``, we can create our events array and plot
# on our raw data.

n_rep = len(peaks) // n_dip
events = np.zeros((len(peaks), 3), int)
events[:, 0] = onsets + raw.first_samp
events[:, 2] = np.tile(np.arange(1, n_dip + 1), n_rep)
raw.plot(events=events, **raw_plot_kwargs)

# %%
# Now we can figure out our epoching parameters and epoch the data, sanity checking
# some values along the way knowing how the stimulation was done.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like this text may need to be updated? Looks like the "sanity checking" is all deleted now


# Sanity check and determine epoching params
deltas = np.diff(events[:, 0], axis=0)
group_deltas = deltas[n_dip - 1 :: n_dip] / raw.info["sfreq"] # gap between 49 and 1
assert (group_deltas > 0.8).all()
assert (group_deltas < 0.9).all()
others = np.delete(deltas, np.arange(n_dip - 1, len(deltas), n_dip)) # remove 49->1
others = others / raw.info["sfreq"]
assert (others > 0.25).all()
assert (others < 0.3).all()
tmax = 1 / dip_freq * 2.0 # 2 cycles
tmin = tmax - others.min()
assert tmin < 0
epochs = mne.Epochs(
raw,
events,
tmin=tmin,
tmax=tmax,
baseline=(None, 0),
decim=10,
picks="data",
preload=True,
)
tmin, tmax = -0.08, 0.18
epochs = mne.Epochs(raw, tmin=tmin, tmax=tmax, decim=10, picks="data", preload=True)
del raw
epochs.plot(scalings=plot_scalings)

Expand All @@ -131,7 +79,7 @@
t_peak = 1.0 / dip_freq / 4.0
data = np.zeros((len(epochs.ch_names), n_dip))
for di in range(n_dip):
data[:, [di]] = epochs[str(di + 1)].average().crop(t_peak, t_peak).data
data[:, [di]] = epochs[f"dip{di + 1:02d}"].average().crop(t_peak, t_peak).data
evoked = mne.EvokedArray(data, epochs.info, tmin=0, comment="KIT phantom activations")
evoked.plot_joint()

Expand All @@ -141,22 +89,12 @@
trans = mne.transforms.Transform("head", "mri", np.eye(4))
sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=0.08)
cov = mne.compute_covariance(epochs, tmax=0, method="empirical")
# We need to correct the ``dev_head_t`` because it's incorrect for these data!
# relative to the helmet: hleft, forward, up
translation = mne.transforms.translation(x=0.01, y=-0.015, z=-0.088)
# pitch down (rot about x/R), roll left (rot about y/A), yaw left (rot about z/S)
rotation = mne.transforms.rotation(
x=np.deg2rad(5),
y=np.deg2rad(-1),
z=np.deg2rad(-3),
)
evoked.info["dev_head_t"]["trans"][:] = translation @ rotation
dip, residual = mne.fit_dipole(evoked, cov, sphere, n_jobs=None)

# %%
# Finally let's look at the results.

# sphinx_gallery_thumbnail_number = 7
# sphinx_gallery_thumbnail_number = 5

print(f"Average amplitude: {np.mean(dip.amplitude) * 1e9:0.1f} nAm")
print(f"Average GOF: {np.mean(dip.gof):0.1f}%")
Expand Down
Loading