Skip to content

Fix rbf kernel #247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions docs/source/examples/data_drift/MMD_simple.ipynb

Large diffs are not rendered by default.

19 changes: 1 addition & 18 deletions frouros/detectors/data_drift/batch/distance_based/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,14 @@

import numpy as np # type: ignore
import tqdm # type: ignore
from scipy.spatial.distance import cdist # type: ignore

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import MultivariateData
from frouros.detectors.data_drift.batch.distance_based.base import (
BaseDistanceBased,
DistanceResult,
)


def rbf_kernel(
X: np.ndarray, Y: np.ndarray, std: float = 1.0 # noqa: N803
) -> np.ndarray:
"""Radial basis function kernel between X and Y matrices.

:param X: X matrix
:type X: numpy.ndarray
:param Y: Y matrix
:type Y: numpy.ndarray
:param std: standard deviation value
:type std: float
:return: Radial basis kernel matrix
:rtype: numpy.ndarray
"""
return np.exp(-cdist(X, Y, "sqeuclidean") / 2 * std**2)
from frouros.utils.kernels import rbf_kernel


class MMD(BaseDistanceBased):
Expand Down
103 changes: 103 additions & 0 deletions frouros/tests/unit/utils/test_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Test kernels module."""

import numpy as np # type: ignore
import pytest # type: ignore

from frouros.utils.kernels import rbf_kernel


# TODO: Create fixtures for the matrices and the expected kernel values


@pytest.mark.parametrize(
"X, Y, sigma, expected_kernel_value",
[
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 0.5, np.array([[1.0]])),
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 1.0, np.array([[1.0]])),
(np.array([[1, 2, 3]]), np.array([[1, 2, 3]]), 2.0, np.array([[1.0]])),
(
np.array([[1, 2, 3]]),
np.array([[4, 5, 6]]),
0.5,
np.array([[3.53262857e-24]]),
),
(
np.array([[1, 2, 3]]),
np.array([[4, 5, 6]]),
1.0,
np.array([[1.37095909e-06]]),
),
(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]), 2.0, np.array([[0.03421812]])),
(
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1, 2, 3], [4, 5, 6]]),
0.5,
np.array(
[[1.00000000e00, 3.53262857e-24], [3.53262857e-24, 1.00000000e00]]
),
),
(
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1, 2, 3], [4, 5, 6]]),
1.0,
np.array(
[[1.00000000e00, 1.37095909e-06], [1.37095909e-06, 1.00000000e00]]
),
),
(
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1, 2, 3], [4, 5, 6]]),
2.0,
np.array([[1.00000000e00, 0.03421812], [0.03421812, 1.00000000e00]]),
),
(
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
0.5,
np.array(
[[2.23130160e-01, 1.20048180e-32], [5.17555501e-17, 2.23130160e-01]]
),
),
(
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
1.0,
np.array(
[[6.87289279e-01, 1.04674018e-08], [8.48182352e-05, 6.87289279e-01]]
),
),
(
np.array([[1, 2, 3], [4, 5, 6]]),
np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]),
2.0,
np.array([[0.91051036, 0.01011486], [0.09596709, 0.91051036]]),
),
],
)
def test_rbf_kernel(
X: np.ndarray, # noqa: N803
Y: np.ndarray,
sigma: float,
expected_kernel_value: np.ndarray,
) -> None:
"""Test rbf kernel.

:param X: X values
:type X: numpy.ndarray
:param Y: Y values
:type Y: numpy.ndarray
:param sigma: sigma value
:type sigma: float
:param expected_kernel_value: expected kernel value
:type expected_kernel_value: numpy.ndarray
"""
assert np.all(
np.isclose(
rbf_kernel(
X=X,
Y=Y,
sigma=sigma,
),
expected_kernel_value,
),
)
21 changes: 21 additions & 0 deletions frouros/utils/kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Kernels module."""

import numpy as np # type: ignore
from scipy.spatial.distance import cdist # type: ignore


def rbf_kernel(
X: np.ndarray, Y: np.ndarray, sigma: float = 1.0 # noqa: N803
) -> np.ndarray:
"""Radial basis function kernel between X and Y matrices.

:param X: X matrix
:type X: numpy.ndarray
:param Y: Y matrix
:type Y: numpy.ndarray
:param sigma: sigma value (equivalent to gamma = 1 / (2 * sigma**2))
:type sigma: float
:return: Radial basis kernel matrix
:rtype: numpy.ndarray
"""
return np.exp(-cdist(X, Y, "sqeuclidean") / (2 * sigma**2))