Skip to content

Fix _mmd from MMD to be an static method #248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 50 additions & 62 deletions frouros/detectors/data_drift/batch/distance_based/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ def __init__(
)
self.kernel = kernel
self.chunk_size = chunk_size
self._chunk_size_x = None
self.X_chunks_combinations = None
self.X_num_samples = None
self.expected_k_xx = None

@property
def chunk_size(self) -> Optional[int]:
Expand Down Expand Up @@ -108,55 +104,16 @@ def _distance_measure(
X: np.ndarray, # noqa: N803
**kwargs,
) -> DistanceResult:
mmd = self._mmd(X=X_ref, Y=X, kernel=self.kernel, **kwargs)
mmd = self._mmd(
X=X_ref,
Y=X,
kernel=self.kernel,
chunk_size=self.chunk_size,
**kwargs,
)
distance_test = DistanceResult(distance=mmd)
return distance_test

def _fit(
self,
X: np.ndarray, # noqa: N803
**kwargs,
) -> None:
super()._fit(X=X)
# Add dimension only for the kernel calculation (if dim == 1)
if X.ndim == 1:
X = np.expand_dims(X, axis=1) # noqa: N806
self.X_num_samples = len(self.X_ref) # type: ignore # noqa: N806

self._chunk_size_x = (
self.X_num_samples
if self.chunk_size is None
else self.chunk_size # type: ignore
)

X_chunks = self._get_chunks( # noqa: N806
data=X,
chunk_size=self._chunk_size_x, # type: ignore
)
xx_chunks_combinations = itertools.product(X_chunks, repeat=2) # noqa: N806

if kwargs.get("verbose", False):
num_chunks = (
math.ceil(self.X_num_samples / self._chunk_size_x) ** 2 # type: ignore
)
xx_chunks_combinations = tqdm.tqdm(
xx_chunks_combinations,
total=num_chunks,
)

k_xx_sum = (
self._compute_kernel(
chunk_combinations=xx_chunks_combinations, # type: ignore
kernel=self.kernel,
)
# Remove diagonal (j!=i case)
- self.X_num_samples # type: ignore
)

self.expected_k_xx = k_xx_sum / ( # type: ignore
self.X_num_samples * (self.X_num_samples - 1) # type: ignore
)

@staticmethod
def _compute_kernel(chunk_combinations: Generator, kernel: Callable) -> float:
k_sum = np.array([kernel(*chunk).sum() for chunk in chunk_combinations]).sum()
Expand All @@ -170,8 +127,8 @@ def _get_chunks(data: np.ndarray, chunk_size: int) -> Generator:
)
return chunks

@staticmethod
def _mmd( # pylint: disable=too-many-locals
self,
X: np.ndarray, # noqa: N803
Y: np.ndarray,
*,
Expand All @@ -183,33 +140,56 @@ def _mmd( # pylint: disable=too-many-locals
X = np.expand_dims(X, axis=1) # noqa: N806
Y = np.expand_dims(Y, axis=1) # noqa: N806

X_chunks = self._get_chunks( # noqa: N806
data=X,
chunk_size=self._chunk_size_x, # type: ignore
x_num_samples = len(X) # noqa: N806
chunk_size_x = (
kwargs["chunk_size"]
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
else x_num_samples
)
x_chunks, x_chunks_copy = itertools.tee( # noqa: N806
MMD._get_chunks(
data=X,
chunk_size=chunk_size_x, # type: ignore
),
2,
)
y_num_samples = len(Y) # noqa: N806
chunk_size_y = y_num_samples if self.chunk_size is None else self.chunk_size
chunk_size_y = (
kwargs["chunk_size"]
if "chunk_size" in kwargs and kwargs["chunk_size"] is not None
else y_num_samples
)
y_chunks, y_chunks_copy = itertools.tee( # noqa: N806
self._get_chunks(
MMD._get_chunks(
data=Y,
chunk_size=chunk_size_y, # type: ignore
),
2,
)
x_chunks_combinations = itertools.product( # noqa: N806
x_chunks,
repeat=2,
)
y_chunks_combinations = itertools.product( # noqa: N806
y_chunks,
repeat=2,
)
xy_chunks_combinations = itertools.product( # noqa: N806
X_chunks,
x_chunks_copy,
y_chunks_copy,
)

if kwargs.get("verbose", False):
num_chunks_x = math.ceil(x_num_samples / chunk_size_x) # type: ignore
num_chunks_y = math.ceil(y_num_samples / chunk_size_y) # type: ignore
num_chunks_x_combinations = num_chunks_x**2
num_chunks_y_combinations = num_chunks_y**2
num_chunks_xy = (
math.ceil(len(X) / self._chunk_size_x) * num_chunks_y # type: ignore
math.ceil(len(X) / chunk_size_x) * num_chunks_y # type: ignore
)
x_chunks_combinations = tqdm.tqdm(
x_chunks_combinations,
total=num_chunks_x_combinations,
)
y_chunks_combinations = tqdm.tqdm(
y_chunks_combinations,
Expand All @@ -220,21 +200,29 @@ def _mmd( # pylint: disable=too-many-locals
total=num_chunks_xy,
)

k_xx_sum = (
MMD._compute_kernel(
chunk_combinations=x_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- x_num_samples # type: ignore
)
k_yy_sum = (
self._compute_kernel(
MMD._compute_kernel(
chunk_combinations=y_chunks_combinations, # type: ignore
kernel=kernel,
)
# Remove diagonal (j!=i case)
- y_num_samples # type: ignore
)
k_xy_sum = self._compute_kernel(
k_xy_sum = MMD._compute_kernel(
chunk_combinations=xy_chunks_combinations, # type: ignore
kernel=kernel,
)
mmd = (
self.expected_k_xx # type: ignore
+k_xx_sum / (x_num_samples * (x_num_samples - 1))
+ k_yy_sum / (y_num_samples * (y_num_samples - 1))
- 2 * k_xy_sum / (self.X_num_samples * y_num_samples) # type: ignore
- 2 * k_xy_sum / (x_num_samples * y_num_samples) # type: ignore
)
return mmd