Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions frouros/data_drift/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Data drift batch detection methods init."""

from .distance_based import (
Bhattacharyya,
EMD,
Hellinger,
HistogramIntersection,
Expand All @@ -17,6 +18,7 @@
)

__all__ = [
"Bhattacharyya",
"ChiSquareTest",
"CVMTest",
"EMD",
Expand Down
2 changes: 2 additions & 0 deletions frouros/data_drift/batch/distance_based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Data drift batch distance based detection methods' init."""

from .bhattacharyya import Bhattacharyya
from .emd import EMD
from .hellinger import Hellinger
from .histogram_intersection import HistogramIntersection
Expand All @@ -9,6 +10,7 @@
from .mmd import MMD

__all__ = [
"Bhattacharyya",
"EMD",
"Hellinger",
"HistogramIntersection",
Expand Down
27 changes: 27 additions & 0 deletions frouros/data_drift/batch/distance_based/bhattacharyya.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Bhattacharyya distance module."""

import numpy as np # type: ignore

from frouros.data_drift.batch.distance_based.base import (
DistanceBinsBasedBase,
)


class Bhattacharyya(DistanceBinsBasedBase):
    """Bins-based detector scoring drift via the Bhattacharyya coefficient.

    The reported value is ``1 - BC``, where ``BC`` is the sum over bins of
    ``sqrt(X_ref_ * X)``.

    NOTE(review): the textbook Bhattacharyya distance is ``-ln(BC)``; this
    class intentionally (per the existing tests) returns ``1 - BC`` instead.
    """

    def _distance_measure_bins(
        self,
        X_ref_: np.ndarray,  # noqa: N803
        X: np.ndarray,  # noqa: N803
    ) -> float:
        """Return the distance between the reference bins and the new bins.

        Thin wrapper that forwards both arrays to :meth:`_bhattacharyya`.

        :param X_ref_: reference values per bin
        :param X: values per bin to compare against the reference
        :return: ``1 - sum(sqrt(X_ref_ * X))``
        """
        return self._bhattacharyya(X_ref_=X_ref_, X=X)

    @staticmethod
    def _bhattacharyya(X_ref_: np.ndarray, X: np.ndarray) -> float:  # noqa: N803
        """Compute one minus the Bhattacharyya coefficient of two arrays.

        Presumably both inputs are normalized histogram frequencies over the
        same bins -- confirm against ``DistanceBinsBasedBase``.

        :param X_ref_: reference values per bin
        :param X: values per bin to compare against the reference
        :return: ``1 - sum(sqrt(X_ref_ * X))``
        """
        # Element-wise geometric mean of the two arrays, summed over all bins.
        overlap = np.sum(np.sqrt(X_ref_ * X))
        return 1 - overlap
5 changes: 3 additions & 2 deletions frouros/tests/test_data_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from frouros.data_drift.batch.base import DataDriftBatchBase
from frouros.data_drift.batch.distance_based import (
Bhattacharyya,
EMD,
Hellinger,
HistogramIntersection,
Expand Down Expand Up @@ -86,7 +87,7 @@ def test_batch_distance_based_univariate(

@pytest.mark.parametrize(
"detector, expected_distance",
[(PSI(), 468.79410784), (Hellinger(), 0.77137797)],
[(PSI(), 468.79410784), (Hellinger(), 0.77137797), (Bhattacharyya(), 0.59502397)],
)
def test_batch_distance_bins_based_univariate_different_distribution(
univariate_distribution_p: Tuple[float, float],
Expand Down Expand Up @@ -120,7 +121,7 @@ def test_batch_distance_bins_based_univariate_different_distribution(

@pytest.mark.parametrize(
"detector, expected_distance",
[(PSI(), 0.01840072), (Hellinger(), 0.04792538)],
[(PSI(), 0.01840072), (Hellinger(), 0.04792538), (Bhattacharyya(), 0.00229684)],
)
def test_batch_distance_bins_based_univariate_same_distribution(
univariate_distribution_p: Tuple[float, float],
Expand Down