diff --git a/amplitude_finder.py b/amplitude_finder.py
index 8189f41..79f2991 100644
--- a/amplitude_finder.py
+++ b/amplitude_finder.py
@@ -4,6 +4,11 @@
 import sys
 import numpy as np
 import pickle
+import time
+
+# def find_amps_dat():
+
+# def find_amps_continuous():
 
 available_cpu_count = len(psutil.Process().cpu_affinity())
 os.environ["MKL_NUM_THREADS"] = str(available_cpu_count)
@@ -22,8 +27,8 @@
 cluster = clusterbank['good_units'][cluster_num]
 spike_times = cluster['times']
 
-dat_loc = os.path.join(home_dir, '100_CHs.dat')
-amps, max_cluster_chan = find_amplitudes(dat_loc, 64, spike_times)
+#dat_loc = os.path.join(home_dir, '100_CHs.dat')
+amps, max_cluster_chan = find_amplitudes(home_dir, 64, spike_times)
 
 amp_and_chan = {'amps':amps, 'max_cluster_chan':max_cluster_chan}
 
diff --git a/amplitude_finder.sh b/amplitude_finder.sh
index 9e1410f..2535121 100644
--- a/amplitude_finder.sh
+++ b/amplitude_finder.sh
@@ -6,7 +6,7 @@
 #SBATCH --error=/home/camp/warnert/outputs/error_%a.txt
 #SBATCH --ntasks=1
 #SBATCH --time=24:00:00
-#SBATCH --mem=30G
+#SBATCH --mem=100G
 #SBATCH --partition=compute
 #
 
diff --git a/clusterbank_maker.py b/clusterbank_maker.py
index 0e46ddd..14be250 100644
--- a/clusterbank_maker.py
+++ b/clusterbank_maker.py
@@ -116,13 +116,16 @@ def make_clusterbank_basic(home_dir, *, dump=True, kilosort2=False):
 
     return cluster_dict
 
-def find_amplitudes(dat_loc, num_of_chans, spike_times, *, bitvolts=0.195, order='F'):
-    dat_file = np.memmap(dat_loc, dtype=np.int16)
-    time_length = int(len(dat_file)/num_of_chans)
-    dat_file = dat_file.reshape((num_of_chans, time_length), order=order)
+def find_amplitudes(data_loc, num_of_chans, spike_times, *, dat=False, bitvolts=0.195, order='F'):
+    if dat:
+        dat_file = np.memmap(data_loc, dtype=np.int16)
+        time_length = int(len(dat_file)/num_of_chans)
+        datas = dat_file.reshape((num_of_chans, time_length), order=order)
+    else:
+        datas = [oe.loadContinuous2(os.path.join(data_loc, '100_CH%d.continuous' % i)) for i in range(1, num_of_chans+1)]
     cluster_spikes = []
     for i in tqdm(range(num_of_chans)):
-        cluster_spikes.append(find_cluster_spikes(dat_file[i], spike_times)[1])
+        cluster_spikes.append(find_cluster_spikes(datas[i], spike_times)[1])
     max_cluster_chan = find_max_chan(cluster_spikes)
     amps = [min(i)*bitvolts for i in cluster_spikes[max_cluster_chan]]
     return amps, max_cluster_chan
@@ -135,10 +138,15 @@
     max_cluster_chan = np.argmin([min(i) for i in mean_cluster_spikes])
     return max_cluster_chan
 
-def make_clusterbank_full(home_dir, num_of_chans, *, bitvolts=0.195, order='F', kilosort2=False, dump=True, dat_name='100_CHs.dat'):
+def make_clusterbank_full(home_dir, num_of_chans, *, bitvolts=0.195, order='F', kilosort2=False, dump=True, dat=False, dat_name='100_CHs.dat'):
     '''
     Makes a clusterbank with all the information
     '''
+    if dat:
+        data_loc = os.path.join(home_dir, dat_name)
+    else:
+        data_loc = home_dir
+
     clusterbank_basic = make_clusterbank_basic(home_dir, dump=False, kilosort2=kilosort2)
     print('Done making basic clusterbank')
     for unit_type in clusterbank_basic:
@@ -148,9 +156,8 @@
             print('Doing cluster', cluster_num)
             cluster = clusterbank_basic[unit_type][cluster_num]
             spike_times = cluster['times']
-            dat_loc = os.path.join(home_dir, dat_name)
             print('Finding amps for', len(spike_times), 'spikes')
-            amps, max_cluster_chan = find_amplitudes(dat_loc, num_of_chans, spike_times, bitvolts=bitvolts, order=order)
+            amps, max_cluster_chan = find_amplitudes(data_loc, num_of_chans, spike_times, dat=dat, bitvolts=bitvolts, order=order)
             cluster['amps'] = amps
             cluster['max_chan'] = max_cluster_chan
     if dump:
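Usage sketch (illustrative, not part of the patch): the new `dat` keyword switches `find_amplitudes` between the flat `100_CHs.dat` int16 binary and per-channel Open Ephys `.continuous` files, and `make_clusterbank_full` now resolves the data location from that flag itself. The recording path and the spike-times file below are hypothetical placeholders; `find_amplitudes` and `make_clusterbank_full` are the functions patched above.

# Usage sketch, assuming the patched clusterbank_maker.py is importable.
import os
import numpy as np
from clusterbank_maker import find_amplitudes, make_clusterbank_full

home_dir = '/path/to/recording'  # hypothetical recording directory
spike_times = np.load(os.path.join(home_dir, 'spike_times.npy'))  # hypothetical source of spike times

# New default: read per-channel .continuous files from the recording directory (dat=False)
amps, max_chan = find_amplitudes(home_dir, 64, spike_times)

# Old behaviour: read the flat int16 binary instead (dat=True)
amps, max_chan = find_amplitudes(os.path.join(home_dir, '100_CHs.dat'), 64, spike_times, dat=True)

# Or build the full clusterbank; it picks data_loc from the dat flag internally
clusterbank = make_clusterbank_full(home_dir, 64, dat=False)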