Skip to content
This repository was archived by the owner on Dec 8, 2023. It is now read-only.

Added tests for NWB export #18

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def dj_config():
dj.config.load('./dj_local_conf.json')
dj.config['safemode'] = False
dj.config['custom'] = {
'ephys_mode': (os.environ.get('EPHYS_MODE')
or dj.config['custom']['ephys_mode']),
'database.prefix': (os.environ.get('DATABASE_PREFIX')
or dj.config['custom']['database.prefix']),
'ephys_root_data_dir': (os.environ.get('EPHYS_ROOT_DATA_DIR').split(',') if os.environ.get('EPHYS_ROOT_DATA_DIR') else dj.config['custom']['ephys_root_data_dir'])
Expand Down Expand Up @@ -112,7 +114,9 @@ def pipeline():
'ephys': pipeline.ephys,
'probe': pipeline.probe,
'session': pipeline.session,
'get_ephys_root_data_dir': pipeline.get_ephys_root_data_dir}
'get_ephys_root_data_dir': pipeline.get_ephys_root_data_dir,
'ephys_mode': pipeline.ephys_mode}

if verbose and _tear_down:
pipeline.subject.Subject.delete()
elif not verbose and _tear_down:
Expand Down Expand Up @@ -252,7 +256,10 @@ def kilosort_paramset(pipeline):

# Insert here, since most of the test will require this paramset inserted
ephys.ClusteringParamSet.insert_new_params(
'kilosort2', 0, 'Spike sorting using Kilosort2', params_ks)
clustering_method='kilosort2.5',
paramset_desc='Spike sorting using Kilosort2.5',
params=params_ks,
paramset_idx=0)

yield params_ks

Expand Down Expand Up @@ -294,9 +301,9 @@ def clustering_tasks(pipeline, kilosort_paramset, ephys_recordings):
kilosort_dir = next(recording_dir.rglob('spike_times.npy')).parent
ephys.ClusteringTask.insert1({**ephys_rec_key,
'paramset_idx': 0,
'clustering_output_dir':
kilosort_dir.as_posix()
}, skip_duplicates=True)
'task_mode': 'load',
'clustering_output_dir': kilosort_dir.as_posix()},
skip_duplicates=True)

yield

Expand Down Expand Up @@ -327,17 +334,21 @@ def clustering(clustering_tasks, pipeline):

@pytest.fixture
def curations(clustering, pipeline):
"""Insert keys from ephys.ClusteringTask into ephys.Curation"""
ephys = pipeline['ephys']
ephys_mode = pipeline['ephys_mode']

for key in (ephys.ClusteringTask - ephys.Curation).fetch('KEY'):
ephys.Curation().create1_from_clustering_task(key)
if ephys_mode == 'no-curation':
yield
else:
ephys = pipeline['ephys']

yield
for key in (ephys.ClusteringTask - ephys.Curation).fetch('KEY'):
ephys.Curation().create1_from_clustering_task(key)

if _tear_down:
if verbose:
ephys.Curation.delete()
else:
with QuietStdOut():
yield

if _tear_down:
if verbose:
ephys.Curation.delete()
else:
with QuietStdOut():
ephys.Curation.delete()
85 changes: 85 additions & 0 deletions tests/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys
import pathlib
import numpy as np

from . import (dj_config, pipeline, test_data,
subjects_csv, ingest_subjects,
sessions_csv, ingest_sessions,
testdata_paths, kilosort_paramset,
ephys_recordings, clustering_tasks, clustering, curations)


def test_subject_nwb_export(ingest_subjects, pipeline):
    """Verify that subject.Subject.make_nwb mirrors the stored subject record."""
    subject = pipeline['subject']
    subject_key = {'subject': 'subject1'}

    # Fetch the ground truth from the database, then build the NWB object.
    expected = (subject.Subject & subject_key).fetch1()
    nwb_subject = subject.Subject.make_nwb(subject_key)

    assert nwb_subject.subject_id == expected['subject']
    assert nwb_subject.sex == expected['sex']
    assert nwb_subject.date_of_birth.date() == expected['subject_birth_date']


def test_session_nwb_export(ingest_sessions, pipeline):
    """Verify that session.Session.make_nwb mirrors the stored session record.

    Checks the session start time, subject id, and experimenter list of the
    generated NWB object against the values stored in the Session tables.
    """
    session = pipeline['session']
    session_key = {'subject': 'subject1', 'session_datetime': '2018-11-22 18:51:26'}
    nwb_session = session.Session.make_nwb(session_key)

    session_info = (session.Session & session_key).fetch1()

    assert nwb_session.session_start_time.strftime('%Y%m%d_%H%M%S') == session_info['session_datetime'].strftime('%Y%m%d_%H%M%S')
    assert nwb_session.subject.subject_id == session_info['subject']
    # Restrict to this session's experimenters: an unrestricted fetch would
    # compare against the experimenters of every session in the table.
    assert nwb_session.experimenter == list(
        (session.SessionExperimenter & session_key).fetch('user'))


def test_ephys_nwb_export(curations, pipeline, testdata_paths):
    """Exercise the ephys NWB export end-to-end for one curated recording.

    Populates CuratedClustering, LFP, and WaveformSet for a single curation,
    exports it to NWB, and checks the device, LFP series, electrode table,
    unit table, and peak waveforms against the database contents.
    """
    ephys = pipeline['ephys']
    probe = pipeline['probe']

    rel_path = testdata_paths['npx3B-p1-ks']
    curation_key = (ephys.Curation
                    & f'curation_output_dir LIKE "%{rel_path}"').fetch1('KEY')
    # WaveformSet depends on CuratedClustering; populate in this order.
    for table in (ephys.CuratedClustering, ephys.LFP, ephys.WaveformSet):
        table.populate(curation_key)

    nwb_ephys = ephys.CuratedClustering.make_nwb(curation_key)

    # check device
    probe_name, probe_type = (ephys.ProbeInsertion * probe.Probe * probe.ProbeType
                              & curation_key).fetch1('probe', 'probe_type')
    assert f'{probe_name} ({probe_type})' in nwb_ephys.devices

    # check LFP
    lfp_name = f'probe_{probe_name} - LFP'
    ecephys_interfaces = nwb_ephys.processing['ecephys'].data_interfaces
    assert lfp_name in ecephys_interfaces

    lfp_timestamps = (ephys.LFP & curation_key).fetch1('lfp_time_stamps')
    lfp_channel_count = len(ephys.LFP.Electrode & curation_key)
    nwb_lfp = ecephys_interfaces[lfp_name].electrical_series['processed_electrical_series']
    assert nwb_lfp.data.shape == (len(lfp_timestamps), lfp_channel_count)

    # check electrodes
    nwb_electrodes = nwb_ephys.electrodes.to_dataframe()
    db_electrodes = (ephys.EphysRecording * probe.ElectrodeConfig.Electrode
                     * probe.ProbeType.Electrode & curation_key).fetch(
                         format='frame').reset_index()
    assert np.array_equal(nwb_electrodes.index, db_electrodes.index)
    assert np.array_equal(nwb_electrodes.rel_x, db_electrodes.x_coord)
    assert np.array_equal(nwb_electrodes.rel_y, db_electrodes.y_coord)

    # check Unit
    unit_df = nwb_ephys.units.to_dataframe()
    all_units = ephys.CuratedClustering.Unit & curation_key

    assert len(all_units) == len(unit_df)
    assert (len(all_units & 'cluster_quality_label = "good"')
            == sum(unit_df.cluster_quality_label == 'good'))

    # check waveform
    expected_waveform = (ephys.WaveformSet.PeakWaveform & curation_key
                         & 'unit = 15').fetch1('peak_electrode_waveform')
    assert np.array_equal(unit_df.iloc[15].waveform_mean, expected_waveform)
4 changes: 2 additions & 2 deletions tests/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,6 @@ def test_paramset_insert(kilosort_paramset, pipeline):
method, desc, paramset_hash = (ephys.ClusteringParamSet
& {'paramset_idx': 0}).fetch1(
'clustering_method', 'paramset_desc', 'param_set_hash')
assert method == 'kilosort2'
assert desc == 'Spike sorting using Kilosort2'
assert method == 'kilosort2.5'
assert desc == 'Spike sorting using Kilosort2.5'
assert dict_to_uuid(kilosort_paramset) == paramset_hash
31 changes: 21 additions & 10 deletions tests/test_populate.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,19 @@ def test_curated_clustering_populate(curations, pipeline, testdata_paths):
ephys = pipeline['ephys']

rel_path = testdata_paths['npx3A-p1-ks']
curation_key = (ephys.Curation & f'curation_output_dir LIKE "%{rel_path}"'
).fetch1('KEY')
curation_key = _get_curation_key(rel_path, pipeline)
ephys.CuratedClustering.populate(curation_key)
assert len(ephys.CuratedClustering.Unit & curation_key
& 'cluster_quality_label = "good"') == 76

rel_path = testdata_paths['oe_npx3B-ks']
curation_key = (ephys.Curation & f'curation_output_dir LIKE "%{rel_path}"'
).fetch1('KEY')
curation_key = _get_curation_key(rel_path, pipeline)
ephys.CuratedClustering.populate(curation_key)
assert len(ephys.CuratedClustering.Unit & curation_key
& 'cluster_quality_label = "good"') == 68

rel_path = testdata_paths['npx3B-p1-ks']
curation_key = (ephys.Curation & f'curation_output_dir LIKE "%{rel_path}"'
).fetch1('KEY')
curation_key = _get_curation_key(rel_path, pipeline)
ephys.CuratedClustering.populate(curation_key)
assert len(ephys.CuratedClustering.Unit & curation_key
& 'cluster_quality_label = "good"') == 55
Expand All @@ -126,8 +123,7 @@ def test_waveform_populate_npx3B_OpenEphys(curations, pipeline, testdata_paths):
"""
ephys = pipeline['ephys']
rel_path = testdata_paths['oe_npx3B-ks']
curation_key = (ephys.Curation & f'curation_output_dir LIKE "%{rel_path}"'
).fetch1('KEY')
curation_key = _get_curation_key(rel_path, pipeline)
ephys.CuratedClustering.populate(curation_key)
ephys.WaveformSet.populate(curation_key)

Expand All @@ -146,12 +142,27 @@ def test_waveform_populate_npx3B_SpikeGLX(curations, pipeline, testdata_paths):
ephys = pipeline['ephys']

rel_path = testdata_paths['npx3B-p1-ks']
curation_key = (ephys.Curation & f'curation_output_dir LIKE "%{rel_path}"'
).fetch1('KEY')
curation_key = _get_curation_key(rel_path, pipeline)
ephys.CuratedClustering.populate(curation_key)
ephys.WaveformSet.populate(curation_key)

waveforms = np.vstack((ephys.WaveformSet.PeakWaveform
& curation_key).fetch('peak_electrode_waveform'))

assert waveforms.shape == (150, 64)


# ---- HELPER FUNCTIONS ----

def _get_curation_key(output_relative_path, pipeline):
    """Return the primary key of the curation (or clustering task, in
    'no-curation' mode) whose output directory ends with the given
    relative path.
    """
    ephys = pipeline['ephys']

    # In 'no-curation' mode there is no Curation table; the clustering task
    # itself carries the output directory.
    if pipeline['ephys_mode'] == 'no-curation':
        table, attr = ephys.ClusteringTask, 'clustering_output_dir'
    else:
        table, attr = ephys.Curation, 'curation_output_dir'

    query = table & f'{attr} LIKE "%{output_relative_path}"'
    return query.fetch1('KEY')
36 changes: 22 additions & 14 deletions workflow_array_ephys/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from workflow_array_ephys.pipeline import subject, ephys, probe, session
from workflow_array_ephys.paths import get_ephys_root_data_dir
from workflow_array_ephys.pipeline import ephys_mode

from element_array_ephys.readers import spikeglx, openephys
from element_interface.utils import find_root_directory, find_full_path
Expand Down Expand Up @@ -44,10 +45,9 @@ def ingest_sessions(session_csv_path='./user_data/sessions.csv', verbose=True):
session_datetimes, insertions = [], []

# search session dir and determine acquisition software
for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'],
['SpikeGLX', 'OpenEphys']):
ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)]
if len(ephys_meta_filepaths):
for ephys_pattern, ephys_acq_type in zip(['*.ap.meta', '*.oebin'], ['SpikeGLX', 'OpenEphys']):
ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
if ephys_meta_filepaths:
acq_software = ephys_acq_type
break
else:
Expand Down Expand Up @@ -99,20 +99,28 @@ def ingest_sessions(session_csv_path='./user_data/sessions.csv', verbose=True):
probe_insertion_list.extend([{**session_key, **insertion
} for insertion in insertions])

session.Session.insert(session_list)
session.SessionDirectory.insert(session_dir_list)
if verbose:
print(f'\n---- Insert {len(session_list)} entry(s) into session.Session ----')

probe.Probe.insert(probe_list)
if verbose:
print(f'\n---- Insert {len(probe_list)} entry(s) into probe.Probe ----')
probe.Probe.insert(probe_list)

if ephys_mode == 'chronic':
ephys.ProbeInsertion.insert(probe_insertion_list,
ignore_extra_fields=True, skip_duplicates=True)
session.Session.insert(session_list)
session.SessionDirectory.insert(session_dir_list)
if verbose:
print(f'\n---- Insert {len(session_list)} entry(s) into session.Session ----')
print(f'\n---- Insert {len(probe_insertion_list)} entry(s) into ephys.ProbeInsertion ----')
else:
session.Session.insert(session_list)
session.SessionDirectory.insert(session_dir_list)
ephys.ProbeInsertion.insert(probe_insertion_list)
if verbose:
print(f'\n---- Insert {len(session_list)} entry(s) into session.Session ----')
print(f'\n---- Insert {len(probe_insertion_list)} entry(s) into ephys.ProbeInsertion ----')

ephys.ProbeInsertion.insert(probe_insertion_list)
if verbose:
print(f'\n---- Insert {len(probe_insertion_list)} entry(s) into '
+ 'ephys.ProbeInsertion ----')
print('\n---- Successfully completed ingest_subjects ----')
print('\n---- Successfully completed workflow_array_ephys/ingest.py ----')


if __name__ == '__main__':
Expand Down
14 changes: 13 additions & 1 deletion workflow_array_ephys/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import datajoint as dj
import os
from element_animal import subject
from element_lab import lab
from element_session import session
from element_array_ephys import probe, ephys
from element_array_ephys import probe

from element_animal.subject import Subject
from element_lab.lab import Source, Lab, Protocol, User, Project
Expand All @@ -15,6 +16,17 @@

db_prefix = dj.config['custom'].get('database.prefix', '')

# ------------- Import the configured "ephys mode" -------------
# Select which flavor of the ephys element this workflow activates.
# Priority: EPHYS_MODE environment variable, then
# dj.config['custom']['ephys_mode'], then the default 'acute'.
ephys_mode = os.getenv('EPHYS_MODE',
                       dj.config['custom'].get('ephys_mode', 'acute'))
if ephys_mode == 'acute':
    from element_array_ephys import ephys
elif ephys_mode == 'chronic':
    # Chronic-implantation variant, aliased to the common name `ephys` so
    # downstream code is mode-agnostic.
    from element_array_ephys import ephys_chronic as ephys
elif ephys_mode == 'no-curation':
    # Variant without the manual Curation step, also aliased to `ephys`.
    from element_array_ephys import ephys_no_curation as ephys
else:
    raise ValueError(f'Unknown ephys mode: {ephys_mode}')

# Activate "lab", "subject", "session" schema ---------------------------------

Expand Down