Skip to content

Commit 08c271c

Browse files
committed
Completed and tested calibration portion of pipeline
1 parent 0daa0b9 commit 08c271c

File tree

9 files changed

+822
-97
lines changed

9 files changed

+822
-97
lines changed

conf/base/catalog.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,16 @@ prefixed_channels:
9898
# compress: 3
9999
# path: data/03_primary/CC01/ccCenterOut/downsampled_signals
100100
# filename_suffix: ".lz4"
101-
# layer: intermediate
101+
# layer: primary
102102

103+
calibration_statistics_pkl:
104+
type: PartitionedDataSet
105+
dataset:
106+
type: pickle.PickleDataSet
107+
backend: pickle
108+
path: data/03_primary/CC01/ccCenterOut/calibration_statistics
109+
filename_suffix: ".pkl"
110+
layer: primary
103111

104112

105113

conf/base/parameters.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
patient_id: 'CC01'
2+
current_experiment: 'center_out'
3+
current_calibration: 'calibration'
24
bci_states:
35
CC01:
46
center_out: ['StimulusCode', 'cursorX', 'cursorY', 'Baseline', 'ResultCode']
7+
calibration: ['StimulusCode']
58
gain: 0.25
69
data_preprocessing:
710
channel_labelling:
@@ -34,4 +37,9 @@ sessions:
3437
'20221027': ['S08', 'S09']
3538
'20221104': ['S09', 'S10']
3639
'20221108': ['S08', 'S09']
37-
'20221111': ['S07', 'S08']
40+
'20221111': ['S07', 'S08']
41+
calibration:
42+
'20221027': ['S01', 'S02', 'S03']
43+
'20221104': ['S01', 'S02']
44+
'20221108': ['S01', 'S02']
45+
'20221111': ['S01', 'S02']

notebooks/Data Ingestion Notebook.ipynb

Lines changed: 636 additions & 43 deletions
Large diffs are not rendered by default.

src/decoding_pipeline/pipelines/data_generation/nodes.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,25 @@
66
from scripts.convert_bci_to_hdf5 import convert_bcistream, convert_dat
77

88

9-
def generate_center_out_hdf5_dataset(bcistreams, selected_sessions, patient_id):
10-
sessions_dict = selected_sessions[patient_id]['center_out']
9+
def generate_hdf5_dataset(bcistreams, selected_sessions, patient_id, current_experiment, current_run_type):
10+
sessions_dict = selected_sessions[patient_id][current_experiment]
11+
calibration_dict = selected_sessions[patient_id]['calibration']
1112

1213
for partition_key, partition_load_func in bcistreams.items():
13-
1414
continue_loop = True
1515
for paradigm_key, date_session_dict in sessions_dict.items():
1616
for date_key, sessions_list in date_session_dict.items():
17-
if date_key in partition_key and partition_key.split('_')[-1] in sessions_list:
18-
continue_loop = False
17+
if current_run_type == 'calibration':
18+
if date_key in partition_key:
19+
calibration_sessions_list = calibration_dict.get(date_key, [])
20+
if len(calibration_sessions_list):
21+
if partition_key.split('_')[-1] in calibration_sessions_list:
22+
continue_loop = False
23+
else:
24+
continue_loop = False
25+
else:
26+
if date_key in partition_key and partition_key.split('_')[-1] in sessions_list:
27+
continue_loop = False
1928

2029
if continue_loop:
2130
continue
@@ -29,26 +38,3 @@ def generate_center_out_hdf5_dataset(bcistreams, selected_sessions, patient_id):
2938
convert_dat(partition_data, h5filename=filename, add_everything=True)
3039

3140
return {}
32-
33-
def generate_calibration_hdf5_dataset(bcistreams, selected_sessions, patient_id):
34-
sessions_dict = selected_sessions[patient_id]['center_out']
35-
36-
for partition_key, partition_load_func in bcistreams.items():
37-
continue_loop = True
38-
for paradigm_key, date_session_dict in sessions_dict.items():
39-
for date_key, sessions_list in date_session_dict.items():
40-
if date_key in partition_key:
41-
continue_loop = False
42-
43-
if continue_loop:
44-
continue
45-
46-
partition_data = partition_load_func()
47-
48-
filename = partition_data.filename.replace(".dat", ".hdf5")
49-
50-
print(filename)
51-
52-
convert_dat(partition_data, h5filename=filename, add_everything=True)
53-
54-
return {}

src/decoding_pipeline/pipelines/data_generation/pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@
33
generated using Kedro 0.18.3
44
"""
55

6-
from .nodes import generate_center_out_hdf5_dataset, generate_calibration_hdf5_dataset
6+
from .nodes import generate_hdf5_dataset
77

88
from kedro.pipeline import Pipeline, node
99

1010
def create_pipeline(**kwargs) -> Pipeline:
1111
return Pipeline([
1212
node(
13-
func=generate_center_out_hdf5_dataset,
14-
inputs=["center_out_dat", "params:sessions", "params:patient_id"],
13+
func=generate_hdf5_dataset,
14+
inputs=["center_out_dat", "params:sessions", "params:patient_id", "params:current_experiment", "params:current_experiment"],
1515
outputs="center_out_hdf5",
1616
name="convert_center_out_dat_to_hdf5_node",
1717
),
1818
node(
19-
func=generate_calibration_hdf5_dataset,
20-
inputs=["calibration_dat", "params:sessions", "params:patient_id"],
19+
func=generate_hdf5_dataset,
20+
inputs=["calibration_dat", "params:sessions", "params:patient_id", "params:current_experiment", "params:current_calibration"],
2121
outputs="calibration_hdf5",
22-
name="generate_calibration_dat_to_hdf5_node"
22+
name="convert_calibration_dat_to_hdf5_node"
2323
)
2424
])

src/decoding_pipeline/pipelines/data_processing/nodes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ def prefix_single_channel_info(channels, patient_id, grid_split):
9494
'ch_suffix_order': ch_suffix_order
9595
}
9696

97-
def extract_bci_data(h5_data, selected_channels, electrode_labels, states, patient_id, gain):
97+
def extract_bci_data(h5_data, selected_channels, electrode_labels, states, patient_id, gain, current_experiment):
9898
eeglabels = electrode_labels['eeglabels']
9999
auxlabels = electrode_labels['auxlabels']
100100

101101
ch_include = selected_channels['ch_include']
102102
ch_exclude = selected_channels['ch_exclude']
103103

104-
selected_states = states[patient_id]['center_out']
104+
selected_states = states[patient_id][current_experiment]
105105

106106
num_channels = len(eeglabels)
107107

@@ -166,8 +166,8 @@ def extract_bci_data(h5_data, selected_channels, electrode_labels, states, patie
166166

167167
return save_dict
168168

169-
def plot_bci_states(partitioned_data, states, patient_id):
170-
state_names = states[patient_id]['center_out']
169+
def plot_bci_states(partitioned_data, states, patient_id, current_experiment):
170+
state_names = states[patient_id][current_experiment]
171171

172172
save_dict = {}
173173
for partition_key, partition_load_func in partitioned_data.items():

src/decoding_pipeline/pipelines/data_processing/pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,26 +49,26 @@ def create_pipeline(**kwargs) -> Pipeline:
4949
),
5050
node(
5151
func=extract_bci_data,
52-
inputs=["center_out_hdf5", "selected_channels", "electrode_labels", "params:bci_states", "params:patient_id", "params:gain"],
52+
inputs=["center_out_hdf5", "selected_channels", "electrode_labels", "params:bci_states", "params:patient_id", "params:gain", "params:current_experiment"],
5353
outputs="center_out_extracted_pkl",
5454
name="extract_bci_data_node"
5555
),
5656
node(
5757
func=extract_bci_data,
58-
inputs=["calibration_hdf5", "selected_channels", "electrode_labels", "params:bci_states", "params:patient_id", "params:gain"],
58+
inputs=["calibration_hdf5", "selected_channels", "electrode_labels", "params:bci_states", "params:patient_id", "params:gain", "params:current_calibration"],
5959
outputs="calibration_extracted_pkl",
6060
name="extract_calibration_data_node"
6161
),
6262
],
6363
namespace="data_extraction",
6464
inputs=set(["calibration_hdf5", "center_out_hdf5", "selected_channels"]),
6565
outputs=set(["center_out_extracted_pkl", "calibration_extracted_pkl"]),
66-
parameters={"params:patient_id": "params:patient_id", "params:gain": "params:gain", "params:bci_states": "params:bci_states"})
66+
parameters={"params:patient_id": "params:patient_id", "params:gain": "params:gain", "params:bci_states": "params:bci_states", "params:current_experiment": "params:current_experiment", "params:current_calibration": "params:current_calibration"})
6767

6868
dataset_metrics_pipeline = pipeline([
6969
node(
7070
func=plot_bci_states,
71-
inputs=["center_out_extracted_pkl", "params:bci_states", "params:patient_id"],
71+
inputs=["center_out_extracted_pkl", "params:bci_states", "params:patient_id", "params:current_experiment"],
7272
outputs="state_plots",
7373
name="plot_bci_states_node"
7474

@@ -77,7 +77,7 @@ def create_pipeline(**kwargs) -> Pipeline:
7777
namespace="dataset_metrics",
7878
inputs=set(["center_out_extracted_pkl"]),
7979
outputs="state_plots",
80-
parameters={"params:patient_id": "params:patient_id", "params:bci_states": "params:bci_states"})
80+
parameters={"params:patient_id": "params:patient_id", "params:bci_states": "params:bci_states", "params:current_experiment": "params:current_experiment"})
8181

8282
# return channel_labelling_pipeline + data_extraction_pipeline
8383

@@ -86,5 +86,5 @@ def create_pipeline(**kwargs) -> Pipeline:
8686
namespace="data_preprocessing",
8787
inputs=set(["calibration_hdf5", "center_out_hdf5"]),
8888
outputs={"prefixed_channels": "prefixed_channels", "center_out_extracted_pkl": "center_out_extracted_pkl", "calibration_extracted_pkl": "calibration_extracted_pkl", "state_plots": "state_plots", "selected_channels": "selected_channels"},
89-
parameters={"params:patient_id": "params:patient_id", "params:gain": "params:gain", "params:bci_states": "params:bci_states"}
89+
parameters={"params:patient_id": "params:patient_id", "params:gain": "params:gain", "params:bci_states": "params:bci_states", "params:current_experiment": "params:current_experiment", "params:current_calibration": "params:current_calibration"}
9090
)

src/decoding_pipeline/pipelines/feature_generation/nodes.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
generated using Kedro 0.18.3
44
"""
55
import copy
6+
from turtle import update
67
import numpy as np
78
import scipy.signal as signal
89

@@ -229,12 +230,14 @@ def downsample_data_to_spectrogram(partitioned_sxx_data, partitioned_signal_data
229230
save_dict[partition_sxx_key] = create_closure_func(_downsample_data_to_spectrogram, sxx_data_func, signal_data_func)
230231
return save_dict
231232

232-
def plot_downsampled_signals(partitioned_sxx_data, partitioned_signal_data):
233+
def plot_downsampled_signals(partitioned_sxx_data, partitioned_signal_data, partitioned_sxx_std_data):
233234
save_dict = {}
234235
for partition_sxx_key, sxx_data_func in partitioned_sxx_data.items():
235236
signal_data_func = partitioned_signal_data[partition_sxx_key]
237+
sxx_std_func = partitioned_sxx_std_data[partition_sxx_key]
236238

237239
sxx_data_dict = sxx_data_func()
240+
sxx_std_data_dict = sxx_std_func()
238241
signal_data_dict = signal_data_func()
239242

240243
stimuli = signal_data_dict['stimuli']
@@ -245,7 +248,9 @@ def plot_downsampled_signals(partitioned_sxx_data, partitioned_signal_data):
245248
t = sxx_data_dict['t']
246249
sxx = sxx_data_dict['sxx']
247250

248-
fig, (ax, ax1, ax2) = plt.subplots(3, figsize=(20,10))
251+
sxx_std = sxx_std_data_dict['sxx']
252+
253+
fig, (ax, ax1, ax2, ax3) = plt.subplots(4, figsize=(20,10))
249254

250255
ax.plot(t_seconds, stimuli[:, 0], color='k', linewidth=1)
251256
ax.margins(x=0)
@@ -268,11 +273,117 @@ def plot_downsampled_signals(partitioned_sxx_data, partitioned_signal_data):
268273

269274
ax2.set_ylabel('Frequency (Hz)')
270275
ax2.set_ylim([0, 140])
271-
ax2.set_xlabel('Time (s)')
272276

273-
ax.set_title('Downsampled States, Signals and Spectrogram')
277+
ax3.pcolormesh(
278+
t,
279+
f,
280+
sxx_std[:,:,0],
281+
# norm=mpl.colors.PowerNorm(gamma=1.0 / 5),
282+
cmap="seismic",
283+
vmin=-3,
284+
vmax=3
285+
# cmap="YlGnBu"
286+
)
287+
288+
ax3.set_ylabel('Frequency (Hz)')
289+
ax3.set_ylim([0, 140])
290+
291+
292+
ax3.set_xlabel('Time (s)')
293+
294+
ax.set_title('Downsampled States, Signals, Raw Spectrogram and Standardized Spectrogram')
274295

275296
save_dict[f"{partition_sxx_key}.png"] = fig
276297

277298
plt.close()
278-
return save_dict
299+
return save_dict
300+
301+
def extract_calibration_statistics(partitioned_calibration_sxx, partitioned_calibration_data, selected_sessions, patient_id):
302+
calibration_dict = selected_sessions[patient_id]['calibration']
303+
304+
# Find all dates that have no session data. By default, all sessions will be included for calibration
305+
updated_dict = {}
306+
for partition_key in list(partitioned_calibration_sxx.keys()):
307+
# TODO: Switch the array based splitting to regex based splitting
308+
date = partition_key.split('_')[-2]
309+
session = partition_key.split('_')[-1]
310+
if date in list(calibration_dict.keys()):
311+
continue
312+
else:
313+
sessions_list = updated_dict.setdefault(date, [])
314+
sessions_list.append(session)
315+
updated_dict[date] = sessions_list
316+
317+
calibration_dict.update(updated_dict)
318+
319+
stimuli = None
320+
save_dict = {}
321+
for date_key, sessions_list in calibration_dict.items():
322+
intermed_list = []
323+
for partition_sxx_key, sxx_data_func in partitioned_calibration_sxx.items():
324+
continue_loop = True
325+
if date_key in partition_sxx_key and partition_sxx_key.split('_')[-1] in sessions_list:
326+
all_stimuli=partitioned_calibration_data[partition_sxx_key]()['stimuli']
327+
stimuli=all_stimuli[:, 0]
328+
329+
# Check to make sure calibration stimuli is not all zeros
330+
assert 1 in stimuli, "Stimuli contains all zeros, think about possibly excluding this calibration session"
331+
332+
continue_loop = False
333+
334+
if continue_loop:
335+
continue
336+
337+
data_dict = sxx_data_func()
338+
339+
sxx = data_dict['sxx']
340+
341+
sxx = sxx[:, stimuli == 1, :]
342+
343+
mean = np.mean(sxx, axis=1)[:,np.newaxis,:]
344+
std = np.std(sxx, axis=1)[:,np.newaxis,:]
345+
sxx_len = sxx.shape[1]
346+
347+
intermed_list.append({
348+
'mean': mean,
349+
'std': std,
350+
'count': sxx_len
351+
})
352+
353+
total_len = np.sum([x['count'] for x in intermed_list])
354+
fractions_len = [x['count']/total_len for x in intermed_list]
355+
356+
global_mean = sum([x['mean']*frac for x,frac in zip(intermed_list, fractions_len)])
357+
global_std = sum([x['std']*frac for x,frac in zip(intermed_list, fractions_len)])
358+
359+
save_dict[f'Calibration_statistics_{date_key}'] = {
360+
'mean': global_mean,
361+
'std': global_std
362+
}
363+
364+
return save_dict
365+
366+
def _standardize_spectrogram(sxx_func, stats_func):
367+
sxx_dict = sxx_func()
368+
stats_dict = stats_func()
369+
370+
sxx = sxx_dict['sxx']
371+
mean_sxx = stats_dict['mean']
372+
std_sxx = stats_dict['std']
373+
374+
sxx_dict['sxx'] = (sxx - mean_sxx)/std_sxx
375+
376+
return sxx_dict
377+
378+
def standardize_spectrograms(partitioned_sxx, partitioned_statistics):
379+
380+
save_dict = {}
381+
for partition_sxx_key, sxx_data_func in partitioned_sxx.items():
382+
date = partition_sxx_key.split('_')[-2]
383+
statistics_key = list(filter(lambda x: x.split('_')[-1] == date, partitioned_statistics.keys()))[0]
384+
385+
stats_func = partitioned_statistics[statistics_key]
386+
387+
save_dict[partition_sxx_key] = create_closure_func(_standardize_spectrogram, sxx_data_func, stats_func)
388+
389+
return save_dict

0 commit comments

Comments (0)