Skip to content

Commit f1cd293

Browse files
authored
Merge pull request #2420 from DradeAW/faster_gaussian_filter
Faster Gaussian filter implementation
2 parents 43cee8a + 0ecf712 commit f1cd293

File tree

4 files changed

+46
-6
lines changed

4 files changed

+46
-6
lines changed

src/spikeinterface/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
from .core_tools import (
8989
read_python,
9090
write_python,
91+
normal_pdf,
9192
)
9293
from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs
9394
from .recording_tools import (

src/spikeinterface/core/core_tools.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,26 @@ def is_editable_mode() -> bool:
453453
import spikeinterface
454454

455455
return (Path(spikeinterface.__file__).parents[2] / "README.md").exists()
456+
457+
458+
def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0):
459+
"""
460+
Manual implementation of the Normal distribution pdf (probability density function).
461+
It is about 8 to 10 times faster than scipy.stats.norm.pdf().
462+
463+
Parameters
464+
----------
465+
x: scalar or array
466+
The x-axis
467+
mu: float, default: 0.0
468+
The mean of the Normal distribution.
469+
sigma: float, default: 1.0
470+
The standard deviation of the Normal distribution.
471+
472+
Returns
473+
-------
474+
normal_pdf: scalar or array (same type as 'x')
475+
The pdf of the Normal distribution for the given x-axis.
476+
"""
477+
478+
return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2))

src/spikeinterface/core/tests/test_core_tools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import platform
2+
import math
23
from multiprocessing.shared_memory import SharedMemory
34
from pathlib import Path
45
import importlib
@@ -10,6 +11,7 @@
1011
make_paths_relative,
1112
make_paths_absolute,
1213
check_paths_relative,
14+
normal_pdf,
1315
)
1416
from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor
1517
from spikeinterface.core.generate import NoiseGeneratorRecording
@@ -87,5 +89,21 @@ def test_path_utils_functions():
8789
assert check_paths_relative(d, r"\\host\share")
8890

8991

92+
def test_normal_pdf() -> None:
93+
mu = 4.160771
94+
sigma = 2.9334
95+
dx = 0.001
96+
97+
xaxis = np.arange(-15, 25, dx)
98+
gauss = normal_pdf(xaxis, mu=mu, sigma=sigma)
99+
100+
assert math.isclose(1.0, dx * np.sum(gauss)) # Checking that sum of pdf is 1
101+
assert math.isclose(mu, dx * np.sum(xaxis * gauss)) # Checking that mean is mu
102+
assert math.isclose(sigma**2, dx * np.sum(xaxis**2 * gauss) - mu**2) # Checking that variance is sigma^2
103+
104+
print(normal_pdf(-0.9355, mu=mu, sigma=sigma))
105+
assert math.isclose(normal_pdf(-0.9355, mu=mu, sigma=sigma), 0.03006929091)
106+
107+
90108
if __name__ == "__main__":
91109
test_path_utils_functions()

src/spikeinterface/preprocessing/filter_gaussian.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from spikeinterface.core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin
5+
from spikeinterface.core import BaseRecording, BaseRecordingSegment, get_chunk_with_margin, normal_pdf
66
from spikeinterface.core.core_tools import define_function_from_class
77
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
88

@@ -95,16 +95,14 @@ def _create_gaussian(self, N: int, cutoff_f: float):
9595
sf = self.parent_recording_segment.sampling_frequency
9696
faxis = np.fft.fftfreq(N, d=1 / sf)
9797

98-
from scipy.stats import norm
99-
10098
if cutoff_f > sf / 8: # The Fourier transform of a Gaussian with a very low sigma isn't a Gaussian.
10199
sigma = sf / (2 * np.pi * cutoff_f)
102-
limit = int(round(6 * sigma)) + 1
100+
limit = int(round(5 * sigma)) + 1
103101
xaxis = np.arange(-limit, limit + 1) / sigma
104-
gaussian = norm.pdf(xaxis) / sigma
102+
gaussian = normal_pdf(xaxis) / sigma
105103
gaussian = np.abs(np.fft.fft(gaussian, n=N))
106104
else:
107-
gaussian = norm.pdf(faxis / cutoff_f) * np.sqrt(2 * np.pi)
105+
gaussian = normal_pdf(faxis / cutoff_f) * np.sqrt(2 * np.pi)
108106

109107
if cutoff_f not in self.cached_gaussian:
110108
self.cached_gaussian[cutoff_f] = dict()

0 commit comments

Comments
 (0)