Skip to content

Commit 9798c75

Browse files
committed
Add documentation.
1 parent 3337037 commit 9798c75

File tree

1 file changed

+125
-12
lines changed

1 file changed

+125
-12
lines changed

src/spikeinterface/generation/session_displacement_generator.py

Lines changed: 125 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,63 @@ def generate_session_displacement_recordings(
5252
extra_outputs=False,
5353
seed=None,
5454
):
55-
""" """
55+
"""
56+
Generate a set of recordings simulating probe drift across recording
57+
sessions.
58+
59+
Rigid drift can be added in the (x, y) direction in `recording_shifts`.
60+
These drifts can be made non-rigid (scaled dependent on the unit location)
61+
with the `non_rigid_gradient` parameter. Amplitude of units can be scaled
62+
(e.g. template signal removed by scaling with zero) by specifying scaling
63+
factors in `recording_amplitude_scalings`.
64+
65+
Parameters
66+
----------
67+
68+
num_units : int
69+
The number of units in the generated recordings.
70+
recording_durations : list
71+
An array of length (num_recordings,) specifying the
72+
duration that each created recording should be.
73+
recording_shifts : list
74+
An array of length (num_recordings,) in which each element
75+
is a 2-element array specifying the (x, y) shift for the recording.
76+
Typically, the first recording will have shift (0, 0) so all further
77+
recordings are shifted relative to it. e.g. to create two recordings,
78+
the second shifted by 50 um in the x-direction and 250 um in the y
79+
direction : ((0, 0), (50, 250)).
80+
non_rigid_gradient : float
81+
Factor which sets the level of non-rigidty in the displacement.
82+
See `calculate_displacement_unit_factor` for details.
83+
recording_amplitude_scalings : dict
84+
A dict with keys:
85+
"method" - order by which to apply the scalings.
86+
"by_passed_order" - scalings are applied to the unit templates
87+
in order passed
88+
"by_firing_rate" - scalings are applied to the units in order of
89+
maximum to minimum firing rate
90+
"by_amplitude_and_firing_rate" - scalings are applied to the units
91+
in order of amplitude * firing_rate (maximum to minimum)
92+
"scalings" - a list of numpy arrays, one for each recording, with
93+
each entry an array of length num_units holding the unit scalings.
94+
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
95+
96+
All other parameters are used as in from `generate_drifting_recording()`.
97+
98+
Returns
99+
-------
100+
output_recordings : list
101+
A list of recordings with units shifted (i.e. replicated probe shift).
102+
output_sortings : list
103+
A list of corresponding sorting objects.
104+
extra_outputs_dict (options) : dict
105+
When `extra_outputs` is `True`, a dict containing variables used
106+
in the generation process.
107+
"unit_locations" : A list (length num records) of shifted unit locations
108+
"templates_array_moved" : list[np.array]
109+
A list (length num records) of (num_units, num_samples, num_channels)
110+
arrays of templates that have been shifted.
111+
"""
56112
_check_generate_session_displacement_arguments(
57113
num_units, recording_durations, recording_shifts, recording_amplitude_scalings
58114
)
@@ -82,7 +138,7 @@ def generate_session_displacement_recordings(
82138

83139
for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)):
84140

85-
displacement_vector, displacement_unit_factor = get_inter_session_displacements(
141+
displacement_vector, displacement_unit_factor = _get_inter_session_displacements(
86142
shift,
87143
non_rigid_gradient,
88144
num_units,
@@ -114,7 +170,7 @@ def generate_session_displacement_recordings(
114170
)
115171

116172
# Generate the (possibly shifted, scaled) unit templates
117-
templates_moved_array = generate_templates(
173+
template_array_moved = generate_templates(
118174
channel_locations,
119175
unit_locations_moved,
120176
sampling_frequency=sampling_frequency,
@@ -124,8 +180,8 @@ def generate_session_displacement_recordings(
124180

125181
if recording_amplitude_scalings is not None:
126182

127-
templates_moved_array = amplitude_scale_templates_in_place(
128-
templates_moved_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
183+
template_array_moved = _amplitude_scale_templates_in_place(
184+
template_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
129185
)
130186

131187
# Bring it all together in a `InjectTemplatesRecording` and
@@ -135,7 +191,7 @@ def generate_session_displacement_recordings(
135191

136192
recording = InjectTemplatesRecording(
137193
sorting=sorting,
138-
templates=templates_moved_array,
194+
templates=template_array_moved,
139195
nbefore=nbefore,
140196
amplitude_factor=None,
141197
parent_recording=noise,
@@ -152,19 +208,46 @@ def generate_session_displacement_recordings(
152208
output_recordings.append(recording)
153209
output_sortings.append(sorting)
154210
extra_outputs_dict["unit_locations"].append(unit_locations_moved)
155-
extra_outputs_dict["template_array_moved"].append(templates_moved_array)
211+
extra_outputs_dict["template_array_moved"].append(template_array_moved)
156212

157213
if extra_outputs:
158214
return output_recordings, output_sortings, extra_outputs_dict
159215
else:
160216
return output_recordings, output_sortings
161217

162218

163-
def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations):
164-
""" """
219+
def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations):
220+
"""
221+
Get the formatted `displacement_vector` and `displacement_unit_factor`
222+
used to shift the `unit_locations`..
223+
224+
Parameters
225+
---------
226+
shift : np.array | list | tuple
227+
A 2-element array with the shift in the (x, y) direction.
228+
non_rigid_gradient : float
229+
Factor which sets the level of non-rigidty in the displacement.
230+
See `calculate_displacement_unit_factor` for details.
231+
num_units : int
232+
Number of units
233+
unit_locations : np.array
234+
(num_units, 3) array of unit locations (x, y, z).
235+
236+
Returns
237+
-------
238+
displacement_vector : np.array
239+
A (:, 2) array of (x, y) of displacements
240+
to add to (i.e. move) unit_locations.
241+
e.g. array([[1, 2]])
242+
displacement_unit_factor : np.array
243+
A (num_units, :) array of scaling values to apply to the
244+
displacement vector in order to add nonrigid shift to
245+
the displacement. Note the same scaling is applied to the
246+
x and y dimension.
247+
"""
165248
displacement_vector = np.atleast_2d(shift)
166249

167-
if non_rigid_gradient is None or shift == (0, 0):
250+
if non_rigid_gradient is None or (shift[0] == 0 and shift[1] == 0):
168251
displacement_unit_factor = np.ones((num_units, 1))
169252
else:
170253
displacement_unit_factor = calculate_displacement_unit_factor(
@@ -178,8 +261,38 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l
178261
return displacement_vector, displacement_unit_factor
179262

180263

181-
def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
182-
""" """
264+
def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
265+
"""
266+
Scale a set of templates given a set of scaling values. The scaling
267+
values can be applied in the order passed, or instead in order of
268+
the unit firing range (max to min) or unit amplitude * firing rate (max to min).
269+
This will chang the `templates_array` in place.
270+
271+
Parameters
272+
----------
273+
templates_array : np.array
274+
A (num_units, num_samples, num_channels) array of
275+
template waveforms for all units.
276+
recording_amplitude_scalings : dict
277+
see `generate_session_displacement_recordings()`.
278+
sorting_extra_outputs : dict
279+
Extra output of `generate_sorting` holding the firing frequency of all units.
280+
The unit order is assumed to match the templates.
281+
rec_idx : int
282+
The index of the recording for which the templates are being scaled.
283+
284+
Notes
285+
-----
286+
This method is used in the context of inter-session displacement. Often,
287+
units may drop out of the recording across sessions. This simulates this by
288+
directly scaling the template (e.g. if scaling by 0, the template is completely
289+
dropped out). The provided scalings can be applied in the order passed, or
290+
in the order of unit firing rate or firing rate * amplitude. The idea is
291+
that it may be desired to remove to downscale the most activate neurons
292+
that contribute most significantly to activity histograms. Similarly,
293+
if amplitude is included in activity histograms the amplitude may
294+
also want to be considered when ordering the units to downscale.
295+
"""
183296
method = recording_amplitude_scalings["method"]
184297

185298
if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]:

0 commit comments

Comments
 (0)