Skip to content

Commit

Permalink
Merge pull request #1 from jenniferColonell/jic_edit
Browse files Browse the repository at this point in the history
Adding comments throughout, some minor corrections
  • Loading branch information
AugustineY07 authored Aug 25, 2023
2 parents a97900d + c7e449c commit 5648aef
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 47 deletions.
33 changes: 20 additions & 13 deletions Example/Example_run.m
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
% An example of tracking units in animal AL032 shank 1

clear;

%----------add packages----------
addpath(genpath('C:\Users\labadmin\Desktop\Neuron Tracking Pipeline\User version')) %NEED CHANGE
addpath(genpath('D:\Data\Pipeline\npy\npy-matlab'))
% EDIT THESE PATHS FOR YOUR SYSTEM
addpath(genpath('C:\Users\colonellj\Documents\matlab_ephys\AY_match_current\Neuron_Tracking')) %Add this repo to MATLAB path
addpath(genpath('C:\Users\colonellj\Documents\npy-matlab-master')) %Add npy-matlab to MATLAB path

%----------define parameter and path----------
input.input_path = 'C:\Users\labadmin\Desktop\Neuron Tracking Pipeline\User version\'; %main directory, NEED CHANGE
input.EMD_path = fullfile(input.input_path,'EMD_input\'); %EMD input directory, NEED CHANGE
%parent directory of the data (see Example directory in this repo for organizaiont) -- EDIT FOR YOUR SYSTEM
input.input_path = 'C:\Users\colonellj\Documents\matlab_ephys\AY_match_current\Neuron_Tracking\Example';
input.EMD_path = input.input_path;
% Input data
input.fs = 30000; %acquisition rate
input.ts = 82; %wf time samples
input.l2_weights = 1500;
input.threshold = 10;
input.dim_mask = logical([1,1,1,0,0,0,0,0,0,1]);
input.dim_mask_physical = logical([1,1,1,0,0,0,0,0,0,0]);
input.dim_mask_wf = logical([0,0,0,0,0,0,0,0,0,1]);
input.threshold = 10; % distance threshold for calling matches as real
input.l2_weights = 1500; % weight for waveform difference, see comments in Pipeline\EMD_matlab\weighted_gdf_nt
input.dim_mask = logical([1,1,1,0,0,0,0,0,0,1]); % used to calculate full distance metric
input.dim_mask_physical = logical([1,1,1,0,0,0,0,0,0,0]); % used to calculate 'position only' part of distance metric
input.dim_mask_wf = logical([0,0,0,0,0,0,0,0,0,1]); % used to calculate 'waveform only' part of distance metric
input.chan_pos_name = 'channel_positions.npy';
input.wf_name = 'ksproc_mean_waveforms.npy';
input.KSLabel_name = 'cluster_KSLabel.tsv';
input.validation = 0;

numData = 5;

Expand All @@ -26,8 +31,8 @@
%----------Unit tracking----------
% Find match of all datasets (default: day n and day n+1, can be changed to track between non-consecutive datasets)
for id = 1:numData-1
input.data_path1 = ['D',num2str(id)]; % frist dataset, NEED CHANGE
input.data_path2 = ['D',num2str(id+1)]; % second dataset, NEED CHANGE
input.data_path1 = ['D',num2str(id)]; % first dataset, CHANGE to track non-consecutively
input.data_path2 = ['D',num2str(id+1)]; % second dataset, CHANGE to track non-consecutively
input.result_path = fullfile(input.input_path,['result',num2str(id),num2str(id+1)]); %result directory
input.input_name = ['input',num2str(id),'.mat'];
input.input_name_post = ['input_post',num2str(id),'.mat'];
Expand All @@ -44,14 +49,16 @@
all_input(id) = load(fullfile(input.input_path,['result',num2str(id),num2str(id+1)], "Input.mat"));
all_output(id) = load(fullfile(input.input_path,['result',num2str(id),num2str(id+1)],'Output.mat'));
end

save(fullfile(input.input_path,'all.mat'),"all_input","all_output")
[chain_all,z_loc,len] = chain_summary(all_input,all_output,numData);
[chain_all,z_loc,len] = chain_summary(all_input,all_output,numData,input.input_path);




%----------Plot chains of interest (waveform, firing rate, original location, drift-corrected location, L2)----------
full_chain = chain_all(len == numData,:); %find chains with length across all datasets
[L2_weight,fr_all,fr_change,x_loc_all,z_loc_all] = chain_stats(all_input,all_output,full_chain,numData);
[L2_weight,fr_all,fr_change,x_loc_all,z_loc_all] = chain_stats(all_input,all_output,full_chain,numData,input.input_path);

numChain = size(full_chain,1);
ichain = 1; %which chain to plot, please enter a number between 1 and numChain as input, NEED CHANGE
Expand Down
16 changes: 8 additions & 8 deletions Example/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
An example with AL032 shank 1 data used in the manuscript.

To run the example, please follow the steps below:
1. Download all files in the 'Example' and 'Pipeline' directories
2. Add both to the MATLAB path
3. Run 'Example_run.m'
1. Clone the repo, or download a zip folder
2. Add the parent folder (.../Neuron_Tracking) to the MATLAB folder with subfolders
3. Edit paths in 'Example_run.m' and run.

Input required to run the pipleline with your own data:
- channel_map.npy
- channel_positions.npy
- cluster_KSLabel.tsv
- ksproc_mean_waveforms.npy: post-processed waveforms using JRClust (https://github.com/jenniferColonell/JRCLUST)
- metrics.csv
- channel_map.npy: standard KS2.5 output
- channel_positions.npy: standard KS2.5 output
- cluster_KSLabel.tsv: standard KS2.5 output
- ksproc_mean_waveforms.npy: mean waveforms in (nUnit x nChan x nTimepoints), saved as npy. In our examples we used C_Waves to calculate the mean waveforms.
- metrics.csv: unit metrics; can be generated with ecephys_pipeline or other tool; only the first two columns (cluster ID and firing rate) are required. Note that cluster ID in the metrics table is 0-based; elsewhere cluster ID is zero based.

4 changes: 4 additions & 0 deletions Pipeline/EMD_matlab/emd_nt.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
% Addition parameters are:
% mw1, mw2: (nUnit x nSite) arrays, mean waveforms for the units included in F1 and F2
% chan_pos: (nSite x 2) array of x and z coordinates of sites on the probe.
% dim_mask: (10 x 1) array for selecting which quantities in weighted_gdf_nt.m are
% included in distance metric
% l2_weight: weight for the waveform similarity portion of the EMD distance
% other weights are hard coded in weighted_gdf_nt.m
% Note that F1 and F2 must be recorded on identical sites of the probe.
%
% EMD Earth Mover's Distance between two signatures
Expand Down
2 changes: 0 additions & 2 deletions Pipeline/EMD_matlab/weighted_gdf_nt.m
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@

% further modified to allow calling l2 distance between 2D waveforms from
% this function.
% TODO -- change weights from hard coded here to passing in as another
% variable; can then dispense with the mask, just send in 0 weights.

dim_weights(1) = 0.1; % centroid x, 1/um^2
dim_weights(2) = 0.1; % centroid z, 1/um^2
Expand Down
3 changes: 2 additions & 1 deletion Pipeline/NT_main.m
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
% This is the main function of the algorithm
% Input: kilosort cluster label, channel map, mean waveforms,
% Output: Unit match assignment
% For more comparisons, users need to write their own loops
% For more comparisons, users need to write their own loops; see
% Example_run.m
function NT_main(input,chan_pos,mwf1,mwf2)

% Estimate location
Expand Down
2 changes: 1 addition & 1 deletion Pipeline/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@

See README in the Example folder for explanations of input data.
189 changes: 189 additions & 0 deletions Pipeline/Supplementary/ks2_working_toBinary.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
function ks2_working_toBinary( varargin )

% Read data preprocessed by KS2 (whitenend, post datashift) and write out
% as a standard binary.

% IMPORTANT NOTE: the preprocessed file only contains channels that KS2
% uses for sorting! Channels with connected=0 in the chanMap file OR
% eliminated due to low spike count (ops.minfr_goodchannels > 0) will not
% be included. In that case, use either the info in the rez file or the phy
% output to specify the correct site geometry and channel map for sorting.

% 2nd important note:
% Make sure you set the correct version of kilosort. Supported versions are
% '2.0' and '2.5' and '3.0'
KSver = '2.5'
triggerStr = 'tcat'; % set to t0 for data not processed by CatGT, tcat for processed data
useRezFroc = 0; % set to 1 if the path to fproc in rez.mat is correct, 0 if it needs to be selected separately

lenTS = numel(triggerStr);

% make sure the rez file is pointing to the current preprocessed data file.
if isempty(varargin )
[rezName, rezPath] = uigetfile('*.mat','Select rez file from KS sort');
rezFullPath = fullfile( rezPath, rezName );
modStr = 'new';
[mmName, mmPath] = uigetfile('*.meta','Select original metadata file');
modelMetaFullPath = fullfile(mmPath, mmName);
if useRezFroc ~= 1
[fprocName, fprocPath] = uigetfile({'*.bin','*.dat'}, 'Select fproc file');
fprocFullPath = fullfile(fprocPath,fprocName);
end

else
% called with rezFullPath
inputCell = varargin(1);
rezFullPath = inputCell{1};
inputCell = varargin(2);
modStr = inputCell{1};
inputCell = varargin(3);
modelMetaFullPath = inputCell{1};
if useRezFroc ~= 1
inputCell = varargin(4);
fprocFullPath = inputCell{1};
end

end

% build output name, get copy of metaName
load( rezFullPath );
ops = rez.ops;

if useRezFroc
fprocFullPath = ops.fproc;
end

[binPath, binName, binExt] = fileparts(rez.ops.fbinary);
tPos = strfind(binName, '.imec');
temp = extractBetween(binName,1,tPos-(lenTS+1));
baseName = temp{1};

suffix = extractAfter(binName,triggerStr);
outName = sprintf( '%s%s%s%s', baseName, modStr, suffix, binExt);
if useRezFroc
% 'Typically' want the output to go in the directory with the input
[outPath, ~, ~] = fileparts(binPath);
outFullPath = fullfile(binPath, outName);
else
% 'Typically' working from a copy of the processed file, and the
% unwhitened file goes there
[outPath, ~, ~] = fileparts(fprocFullPath);
outFullPath = fullfile(outPath,outName);
end
metaName = sprintf( '%s%s', extractBefore(outName,'bin'), 'meta');

% KS2 creates a temporary whitened file consisting of batches of NT time
% points stored as NT rows by NChan columns (so each column the time trace
% for a single channel). These blocks are overlapped in time by nt.buff
% timepoints, which get trimmed off in analysis to avoid some filtering
% artifacts. To translate these data to "standard binary format, need to
% read each block, trim off the buffer, unwhiten and rescale, and finally
% tranpose to Nchan rows by NT-ntbuff columns.

% KS2.5 aims to make a readable binary from it's datashifted data. Rather
% than overlapping the batches, it reads in some extra points for filtering
% and then trims them back off. ops.ntbuff = ntb
% read NTbuff = NT + 3*ntb points
% for a standard batch (neither first nor last) read:
% --ntb points overlapping the last batch
% --NT points that belong to this batch
% --2*ntb more points; first set of ntb will be blended with next
% batch, 2nd ntb just filtering buffer
% After fitering, points ntb+1:ntb are blended with "dat_prev" which is
% NT+ntb+1:NT+2*ntb saved from the previous batch.
% Batch zero gets ntb zeros prepended to its data and blended with the
% initialized dat_prev (also zeros).
% After filtering, the data is whitened and then transposed bacyk to NChan
% rows by NT columns to save. When these batches are read for sorting in
% learnTemplates, the data is transposed after reading (new in KS25)




fid = fopen(fprocFullPath, 'r');
fidW = fopen(outFullPath, 'w'); % open for writing processed data, transposed

if strcmp(KSver,'2.0')
batchstart = 0:ops.NT:ops.NT*ops.Nbatch; % batches start at these timepoints
for ibatch = 1:ops.Nbatch
offset = 2 * ops.Nchan*batchstart(ibatch); % binary file offset in bytes
fseek(fid, offset, 'bof');
dat = fread(fid, [ops.NT ops.Nchan], '*int16');
% Due to clumsy arithmetic in KS, the first two batches overlap by
% 2*ops.ntbuff, while later batches overlap by 1*ntbuff
% trim samples of the end of the current batch accordingly
if ibatch == 1
dat = dat(1:ops.NT-2*ops.ntbuff, :);
else
dat = dat(1:ops.NT-ops.ntbuff, :);
end
% rescale by the average of the diagonal of the whitening
% matrix. This will keep the voltage values in range during the 2nd
% sort.
%norm = mean(diag(rez.Wrot));
%dat = int16(single(dat)./norm);
dat = int16(single(dat)/rez.Wrot);
fwrite(fidW, dat', 'int16'); % write transposed batch to binary, these chunks are in order
end
elseif strcmp(KSver,'2.5') || strcmp(KSver,'3.0')
%Batches already stored end to end and transposed. Just need to read,
%transpose, unwhiten, tranpose back and store
batchstart = 0:ops.NT:ops.NT*ops.Nbatch; % batches start at these timepoints
for ibatch = 1:ops.Nbatch
offset = 2 * ops.Nchan*batchstart(ibatch); % binary file offset in bytes
fseek(fid, offset, 'bof');
dat = fread(fid, [ops.Nchan ops.NT], '*int16');
% if skipping unwhitening, rescale by the average of the diagonal of the whitening
% matrix. This will keep the voltage values in range during the 2nd
% sort.
%norm = mean(diag(rez.Wrot));
%dat = int16(single(dat)./norm);
% if unwhitening, skip the rescale step. Transpose to NT by Nchan
% before unwhitening
dat = int16(single(dat')/rez.Wrot);
fwrite(fidW, dat', 'int16'); % write transposed batch to binary, these chunks are in order
end
else
fprintf( 'Unknown version of Kilsort.\n');

end

fclose(fid);
fclose(fidW);

if strlength(modelMetaFullPath) > 0

fp = dir(outFullPath);
newTags = cell(4,1);
newTag{1} = sprintf('%s%d', 'fileSizeBytes=', fp.bytes);
newTag{2} = sprintf('%s%d', 'nSavedChans=', ops.Nchan);
newTag{3} = sprintf('%s%d%s', 'snsApLfSy=', ops.Nchan, ',0,0');
newTag{4} = sprintf('%s%d', 'snsSaveChanSubset=0:',ops.Nchan-1);
repTags = cell(4,1);
repTags{1} = 'fileSizeBytes';
repTags{2} = 'nSavedChans';
repTags{3} = 'snsApLfSy';
repTags{4} = 'snsSaveChanSubset';

fmodel = fopen( modelMetaFullPath, 'r');
fmeta = fopen( fullfile(outPath, metaName), 'w');

tline = fgetl(fmodel);
while ischar(tline)
currTag = extractBefore(tline,'=');
tagFound = find(strcmp(repTags, currTag));
if isempty(tagFound)
%copy over this line as is
fprintf(fmeta, '%s\n', tline );
else
fprintf('found: %s\n', repTags{tagFound} );
fprintf(fmeta, '%s\n', newTag{tagFound} );
end
tline = fgetl(fmodel);
end
fclose(fmeta);
fclose(fmodel);
end
fprintf( 'Output file has %d channels\n', ops.Nchan );

end
4 changes: 2 additions & 2 deletions Pipeline/chain/chain_stats.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
% calculate L2, FR, and locations of all chains
function [L2_weight,fr_all,fr_change,x_loc,z_loc] = chain_stats(all_input,all_output,full_chain,numData)
function [L2_weight,fr_all,fr_change,x_loc,z_loc] = chain_stats(all_input,all_output,full_chain,numData,output_path)

for ichain = 1:size(full_chain,1)
for id = 1:numData-1
Expand Down Expand Up @@ -68,5 +68,5 @@
end
end

save('C:\Users\labadmin\Desktop\Neuron Tracking Pipeline\User version\chain_stats.mat','full_chain','L2_weight','fr_all','fr_change','z_loc','x_loc')
save(fullfile(output_path,'chain_stats.mat'),'full_chain','L2_weight','fr_all','fr_change','z_loc','x_loc')
end
4 changes: 2 additions & 2 deletions Pipeline/chain_summary.m
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
% This is the second function of the algorithm
% Input: Unit assignment from multiple datasets
% Output: Summary of chains within distance threshold
function [chain_all,z_loc,len] = chain_summary(all_input,all_output,numData)
function [chain_all,z_loc,len] = chain_summary(all_input,all_output,numData,output_path)

% Summarize chains
numChain = 200; %set to a number large number
Expand Down Expand Up @@ -51,5 +51,5 @@
len = sum(chain_all ~= 0,2); %chain lengths
end

save('C:\Users\labadmin\Desktop\Neuron Tracking Pipeline\User version\chain_summary.mat', "chain_all", "z_loc", "len");
save(fullfile(output_path,'chain_summary.mat'), "chain_all", "z_loc", "len");
end
31 changes: 18 additions & 13 deletions Pipeline/create_input/wf_metric/wave_metrics.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
function meas_out = wave_metrics(mw, chan_pos, input)
% Calculate an array of simple waveform metrics from the mean waveforms for
% each unit.
% 1D waveform characteristics based on allen ecephys pipeline
% Position fitting based on Boussard 2D fit to 1/R

[nUnit, nChan, nt] = size(mw);
pp_all = squeeze(max(mw,[],3)-min(mw,[],3));
Expand Down Expand Up @@ -106,20 +110,21 @@
% 2D waveform metrics
pp_unit = squeeze(pp_all(i,:))';

% Julien Boussard - style fit of peak-to-peak voltage vs position
% if background sub pp_unit > 60 uV, attempt a fit of the
% background subtracted pp_all
if max(squeeze(pp_all(i,:))) > 60
fitvals = fit_loc(i, pp_all, chan_pos);
fitX = fitvals(1);
fitZ = fitvals(2);
fitY = fitvals(3);
else
fitX = chan_pos(pk_chan(i),1);
fitZ = chan_pos(pk_chan(i),2);
fitY = -1; % a marker for no fit
end
% Julien Boussard - style fit of peak-to-peak voltage vs position
% if background sub pp_unit > 60 uV, attempt a fit of the
% background subtracted pp_all
if max(squeeze(pp_all(i,:))) > 60
fitvals = fit_loc(i, pp_all, chan_pos);
fitX = fitvals(1);
fitZ = fitvals(2);
fitY = fitvals(3);
else
fitX = chan_pos(pk_chan(i),1);
fitZ = chan_pos(pk_chan(i),2);
fitY = -1; % a marker for no fit
end

% calculate spread of waveforms in z
sp_thresh = 0.2*max(pp_unit);
chan_above = pp_unit > sp_thresh;
zmax = max(chan_pos(chan_above,2));
Expand Down
Loading

0 comments on commit 5648aef

Please sign in to comment.