@@ -52,7 +52,63 @@ def generate_session_displacement_recordings(
52
52
extra_outputs = False ,
53
53
seed = None ,
54
54
):
55
- """ """
55
+ """
56
+ Generate a set of recordings simulating probe drift across recording
57
+ sessions.
58
+
59
+ Rigid drift can be added in the (x, y) direction in `recording_shifts`.
60
+ These drifts can be made non-rigid (scaled dependent on the unit location)
61
+ with the `non_rigid_gradient` parameter. Amplitude of units can be scaled
62
+ (e.g. template signal removed by scaling with zero) by specifying scaling
63
+ factors in `recording_amplitude_scalings`.
64
+
65
+ Parameters
66
+ ----------
67
+
68
+ num_units : int
69
+ The number of units in the generated recordings.
70
+ recording_durations : list
71
+ An array of length (num_recordings,) specifying the
72
+ duration that each created recording should be.
73
+ recording_shifts : list
74
+ An array of length (num_recordings,) in which each element
75
+ is a 2-element array specifying the (x, y) shift for the recording.
76
+ Typically, the first recording will have shift (0, 0) so all further
77
+ recordings are shifted relative to it. e.g. to create two recordings,
78
+ the second shifted by 50 um in the x-direction and 250 um in the y
79
+ direction : ((0, 0), (50, 250)).
80
+ non_rigid_gradient : float
81
+ Factor which sets the level of non-rigidty in the displacement.
82
+ See `calculate_displacement_unit_factor` for details.
83
+ recording_amplitude_scalings : dict
84
+ A dict with keys:
85
+ "method" - order by which to apply the scalings.
86
+ "by_passed_order" - scalings are applied to the unit templates
87
+ in order passed
88
+ "by_firing_rate" - scalings are applied to the units in order of
89
+ maximum to minimum firing rate
90
+ "by_amplitude_and_firing_rate" - scalings are applied to the units
91
+ in order of amplitude * firing_rate (maximum to minimum)
92
+ "scalings" - a list of numpy arrays, one for each recording, with
93
+ each entry an array of length num_units holding the unit scalings.
94
+ e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
95
+
96
+ All other parameters are used as in from `generate_drifting_recording()`.
97
+
98
+ Returns
99
+ -------
100
+ output_recordings : list
101
+ A list of recordings with units shifted (i.e. replicated probe shift).
102
+ output_sortings : list
103
+ A list of corresponding sorting objects.
104
+ extra_outputs_dict (options) : dict
105
+ When `extra_outputs` is `True`, a dict containing variables used
106
+ in the generation process.
107
+ "unit_locations" : A list (length num records) of shifted unit locations
108
+ "templates_array_moved" : list[np.array]
109
+ A list (length num records) of (num_units, num_samples, num_channels)
110
+ arrays of templates that have been shifted.
111
+ """
56
112
_check_generate_session_displacement_arguments (
57
113
num_units , recording_durations , recording_shifts , recording_amplitude_scalings
58
114
)
@@ -82,7 +138,7 @@ def generate_session_displacement_recordings(
82
138
83
139
for rec_idx , (shift , duration ) in enumerate (zip (recording_shifts , recording_durations )):
84
140
85
- displacement_vector , displacement_unit_factor = get_inter_session_displacements (
141
+ displacement_vector , displacement_unit_factor = _get_inter_session_displacements (
86
142
shift ,
87
143
non_rigid_gradient ,
88
144
num_units ,
@@ -114,7 +170,7 @@ def generate_session_displacement_recordings(
114
170
)
115
171
116
172
# Generate the (possibly shifted, scaled) unit templates
117
- templates_moved_array = generate_templates (
173
+ template_array_moved = generate_templates (
118
174
channel_locations ,
119
175
unit_locations_moved ,
120
176
sampling_frequency = sampling_frequency ,
@@ -124,8 +180,8 @@ def generate_session_displacement_recordings(
124
180
125
181
if recording_amplitude_scalings is not None :
126
182
127
- templates_moved_array = amplitude_scale_templates_in_place (
128
- templates_moved_array , recording_amplitude_scalings , sorting_extra_outputs , rec_idx
183
+ template_array_moved = _amplitude_scale_templates_in_place (
184
+ template_array_moved , recording_amplitude_scalings , sorting_extra_outputs , rec_idx
129
185
)
130
186
131
187
# Bring it all together in a `InjectTemplatesRecording` and
@@ -135,7 +191,7 @@ def generate_session_displacement_recordings(
135
191
136
192
recording = InjectTemplatesRecording (
137
193
sorting = sorting ,
138
- templates = templates_moved_array ,
194
+ templates = template_array_moved ,
139
195
nbefore = nbefore ,
140
196
amplitude_factor = None ,
141
197
parent_recording = noise ,
@@ -152,19 +208,46 @@ def generate_session_displacement_recordings(
152
208
output_recordings .append (recording )
153
209
output_sortings .append (sorting )
154
210
extra_outputs_dict ["unit_locations" ].append (unit_locations_moved )
155
- extra_outputs_dict ["template_array_moved" ].append (templates_moved_array )
211
+ extra_outputs_dict ["template_array_moved" ].append (template_array_moved )
156
212
157
213
if extra_outputs :
158
214
return output_recordings , output_sortings , extra_outputs_dict
159
215
else :
160
216
return output_recordings , output_sortings
161
217
162
218
163
- def get_inter_session_displacements (shift , non_rigid_gradient , num_units , unit_locations ):
164
- """ """
219
+ def _get_inter_session_displacements (shift , non_rigid_gradient , num_units , unit_locations ):
220
+ """
221
+ Get the formatted `displacement_vector` and `displacement_unit_factor`
222
+ used to shift the `unit_locations`..
223
+
224
+ Parameters
225
+ ---------
226
+ shift : np.array | list | tuple
227
+ A 2-element array with the shift in the (x, y) direction.
228
+ non_rigid_gradient : float
229
+ Factor which sets the level of non-rigidty in the displacement.
230
+ See `calculate_displacement_unit_factor` for details.
231
+ num_units : int
232
+ Number of units
233
+ unit_locations : np.array
234
+ (num_units, 3) array of unit locations (x, y, z).
235
+
236
+ Returns
237
+ -------
238
+ displacement_vector : np.array
239
+ A (:, 2) array of (x, y) of displacements
240
+ to add to (i.e. move) unit_locations.
241
+ e.g. array([[1, 2]])
242
+ displacement_unit_factor : np.array
243
+ A (num_units, :) array of scaling values to apply to the
244
+ displacement vector in order to add nonrigid shift to
245
+ the displacement. Note the same scaling is applied to the
246
+ x and y dimension.
247
+ """
165
248
displacement_vector = np .atleast_2d (shift )
166
249
167
- if non_rigid_gradient is None or shift == ( 0 , 0 ):
250
+ if non_rigid_gradient is None or ( shift [ 0 ] == 0 and shift [ 1 ] == 0 ):
168
251
displacement_unit_factor = np .ones ((num_units , 1 ))
169
252
else :
170
253
displacement_unit_factor = calculate_displacement_unit_factor (
@@ -178,8 +261,38 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l
178
261
return displacement_vector , displacement_unit_factor
179
262
180
263
181
- def amplitude_scale_templates_in_place (templates_array , recording_amplitude_scalings , sorting_extra_outputs , rec_idx ):
182
- """ """
264
+ def _amplitude_scale_templates_in_place (templates_array , recording_amplitude_scalings , sorting_extra_outputs , rec_idx ):
265
+ """
266
+ Scale a set of templates given a set of scaling values. The scaling
267
+ values can be applied in the order passed, or instead in order of
268
+ the unit firing range (max to min) or unit amplitude * firing rate (max to min).
269
+ This will chang the `templates_array` in place.
270
+
271
+ Parameters
272
+ ----------
273
+ templates_array : np.array
274
+ A (num_units, num_samples, num_channels) array of
275
+ template waveforms for all units.
276
+ recording_amplitude_scalings : dict
277
+ see `generate_session_displacement_recordings()`.
278
+ sorting_extra_outputs : dict
279
+ Extra output of `generate_sorting` holding the firing frequency of all units.
280
+ The unit order is assumed to match the templates.
281
+ rec_idx : int
282
+ The index of the recording for which the templates are being scaled.
283
+
284
+ Notes
285
+ -----
286
+ This method is used in the context of inter-session displacement. Often,
287
+ units may drop out of the recording across sessions. This simulates this by
288
+ directly scaling the template (e.g. if scaling by 0, the template is completely
289
+ dropped out). The provided scalings can be applied in the order passed, or
290
+ in the order of unit firing rate or firing rate * amplitude. The idea is
291
+ that it may be desired to remove to downscale the most activate neurons
292
+ that contribute most significantly to activity histograms. Similarly,
293
+ if amplitude is included in activity histograms the amplitude may
294
+ also want to be considered when ordering the units to downscale.
295
+ """
183
296
method = recording_amplitude_scalings ["method" ]
184
297
185
298
if method in ["by_amplitude_and_firing_rate" , "by_firing_rate" ]:
0 commit comments