Skip to content
This repository was archived by the owner on Dec 8, 2023. It is now read-only.

activate electrode-localization #55

Merged
merged 4 commits into from
Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ element-array-ephys==0.1.0b0
element-lab>=0.1.0b0
element-animal==0.1.0b0
element-session==0.1.0b0
element-event @ git+https://github.com/datajoint/element-event.git
element-interface @ git+https://github.com/datajoint/element-interface.git
ipykernel==6.0.1
58 changes: 12 additions & 46 deletions workflow_array_ephys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class Trial(dj.Part):
class SpikesAlignment(dj.Computed):
definition = """
-> SpikesAlignmentCondition
-> ephys.CuratedClustering
"""

class AlignedTrialSpikes(dj.Part):
Expand All @@ -52,66 +51,33 @@ class UnitPSTH(dj.Part):

def make(self, key):
unit_keys, unit_spike_times = (ephys.CuratedClustering.Unit & key).fetch('KEY', 'spike_times', order_by='unit')

trial_keys, trial_starts, trial_ends = (trial.Trial & (SpikesAlignmentCondition.Trial & key)).fetch(
'KEY', 'trial_start_time', 'trial_stop_time', order_by='trial_id')

bin_size = (SpikesAlignmentCondition & key).fetch1('bin_size')

alignment_spec = (event.AlignmentEvent & key).fetch1()
trialized_event_times = trial.get_trialized_alignment_event_times(
key, trial.Trial & (SpikesAlignmentCondition.Trial & key))

min_limit = (trialized_event_times.event - trialized_event_times.start).max()
max_limit = (trialized_event_times.end - trialized_event_times.event).max()

# Spike raster
aligned_trial_spikes = []
units_spike_raster = {u['unit']: {**key, **u, 'aligned_spikes': []} for u in unit_keys}
min_limit, max_limit = np.Inf, -np.Inf
for trial_key, trial_start, trial_stop in zip(trial_keys, trial_starts, trial_ends):
alignment_event_time = (event.Event & key & {'event_type': alignment_spec['alignment_event_type']}
& f'event_start_time BETWEEN {trial_start} AND {trial_stop}')
if alignment_event_time:
# if there are multiple of such alignment event, pick the last one in the trial
alignment_event_time = alignment_event_time.fetch(
'event_start_time', order_by='event_start_time DESC', limit=1)[0]
else:
for _, r in trialized_event_times.iterrows():
if np.isnan(r.event):
continue

alignment_start_time = (event.Event & key & {'event_type': alignment_spec['start_event_type']}
& f'event_start_time < {alignment_event_time}')
if alignment_start_time:
# if there are multiple of such start event, pick the most immediate one prior to the alignment event
alignment_start_time = alignment_start_time.fetch(
'event_start_time', order_by='event_start_time DESC', limit=1)[0]
alignment_start_time = max(alignment_start_time, trial_start)
else:
alignment_start_time = trial_start

alignment_end_time = (event.Event & key & {'event_type': alignment_spec['end_event_type']}
& f'event_start_time > {alignment_event_time}')
if alignment_end_time:
# if there are multiple of such start event, pick the most immediate one following the alignment event
alignment_end_time = alignment_end_time.fetch(
'event_start_time', order_by='event_start_time', limit=1)[0]
alignment_end_time = min(alignment_end_time, trial_stop)
else:
alignment_end_time = trial_stop

alignment_event_time += alignment_spec['alignment_time_shift']
alignment_start_time += alignment_spec['start_time_shift']
alignment_end_time += alignment_spec['end_time_shift']

min_limit = min(alignment_start_time - alignment_event_time, min_limit)
max_limit = max(alignment_end_time - alignment_event_time, max_limit)

alignment_start_time = r.event - min_limit
alignment_end_time = r.event + max_limit
for unit_key, spikes in zip(unit_keys, unit_spike_times):
aligned_spikes = spikes[(alignment_start_time <= spikes)
& (spikes < alignment_end_time)] - alignment_event_time
aligned_trial_spikes.append({**key, **unit_key, **trial_key, 'aligned_spike_times': aligned_spikes})
& (spikes < alignment_end_time)] - r.event
aligned_trial_spikes.append({**key, **unit_key, **r.trial_key, 'aligned_spike_times': aligned_spikes})
units_spike_raster[unit_key['unit']]['aligned_spikes'].append(aligned_spikes)

# PSTH
for unit_spike_raster in units_spike_raster.values():
spikes = np.concatenate(unit_spike_raster['aligned_spikes'])

psth, edges = np.histogram(spikes, bins=np.arange(min_limit, max_limit, bin_size))
psth, edges = np.histogram(spikes, bins=np.arange(-min_limit, max_limit, bin_size))
unit_spike_raster['psth'] = psth / len(unit_spike_raster.pop('aligned_spikes')) / bin_size
unit_spike_raster['psth_edges'] = edges[1:]

Expand Down
17 changes: 17 additions & 0 deletions workflow_array_ephys/paths.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datajoint as dj
import pathlib
from element_interface.utils import find_full_path


def get_ephys_root_data_dir():
Expand All @@ -9,3 +11,18 @@ def get_session_directory(session_key: dict) -> str:
from .pipeline import session
session_dir = (session.SessionDirectory & session_key).fetch1('session_dir')
return session_dir


def get_electrode_localization_dir(probe_insertion_key: dict) -> str:
from .pipeline import ephys
acq_software = (ephys.EphysRecording & probe_insertion_key).fetch1('acq_software')

if acq_software == 'SpikeGLX':
spikeglx_meta_filepath = pathlib.Path((ephys.EphysRecording.EphysFile & probe_insertion_key
& 'file_path LIKE "%.ap.meta"').fetch1('file_path'))
probe_dir = find_full_path(get_ephys_root_data_dir(), spikeglx_meta_filepath.parent)
elif acq_software == 'Open Ephys':
probe_path = (ephys.EphysRecording.EphysFile & probe_insertion_key).fetch1('file_path')
probe_dir = find_full_path(get_ephys_root_data_dir(), probe_path)

return probe_dir
32 changes: 27 additions & 5 deletions workflow_array_ephys/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from element_animal import subject
from element_lab import lab
from element_session import session
from element_trial import trial, event
from element_event import trial, event
from element_array_ephys import probe
from element_electrode_localization import coordinate_framework, electrode_localization

from element_animal.subject import Subject
from element_lab.lab import Source, Lab, Protocol, User, Project
from element_session.session_with_datetime import Session

from .paths import get_ephys_root_data_dir, get_session_directory
from .paths import get_ephys_root_data_dir, get_session_directory, get_electrode_localization_dir

if 'custom' not in dj.config:
dj.config['custom'] = {}
Expand All @@ -30,8 +31,8 @@
raise ValueError(f'Unknown ephys mode: {ephys_mode}')

__all__ = ['subject', 'lab', 'session', 'trial', 'event', 'probe', 'ephys', 'Subject',
'Source', 'Lab', 'Protocol', 'User', 'Project', 'Session',
'get_ephys_root_data_dir', 'get_session_directory']
'Source', 'Lab', 'Protocol', 'User', 'Project', 'Session', 'coordinate_framework', 'electrode_localization',
'get_ephys_root_data_dir', 'get_session_directory', 'get_electrode_localization_dir']


# Activate "lab", "subject", "session" schema ---------------------------------
Expand All @@ -43,7 +44,10 @@
Experimenter = lab.User
session.activate(db_prefix + 'session', linking_module=__name__)

trial.activate(db_prefix + 'trial', db_prefix + 'event', linking_module= __name__)

# Activate "event" and "trial" schema ---------------------------------

trial.activate(db_prefix + 'trial', db_prefix + 'event', linking_module=__name__)


# Declare table "SkullReference" for use in element-array-ephys ---------------
Expand All @@ -61,3 +65,21 @@ class SkullReference(dj.Lookup):
ephys.activate(db_prefix + 'ephys',
db_prefix + 'probe',
linking_module=__name__)

# Activate "electrode-localization" schema ------------------------------------

ProbeInsertion = ephys.ProbeInsertion
Electrode = probe.ProbeType.Electrode

electrode_localization.activate(db_prefix + 'electrode_localization',
db_prefix + 'ccf',
linking_module=__name__)

ccf_id = 0
voxel_resolution = 100

if not (coordinate_framework.CCF & {'ccf_id': ccf_id}):
coordinate_framework.load_ccf_annotation(
ccf_id=ccf_id, version_name='ccf_2017', voxel_resolution=voxel_resolution,
nrrd_filepath=f'./data/annotation_{voxel_resolution}.nrrd',
ontology_csv_filepath='./data/query.csv')