Skip to content
116 changes: 116 additions & 0 deletions aeon/clustering/_kernel_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,122 @@
from typing import Optional, Union

import numpy as np
from numba import njit
from numpy.random import RandomState

from aeon.clustering.base import BaseClusterer
from aeon.distances.pointwise._squared import squared_pairwise_distance


@njit(cache=True, fastmath=True)
def _kdtw_lk(x, y, local_kernel):
channels = np.shape(x)[1]
padding_vector = np.zeros((1, channels))

x = np.concatenate((padding_vector, x), axis=0)
y = np.concatenate((padding_vector, y), axis=0)

x_timepoints, _ = np.shape(x)
y_timepoints, _ = np.shape(y)

cost_matrix = np.zeros((x_timepoints, y_timepoints))
cumulative_dp_diag = np.zeros((x_timepoints, y_timepoints))
diagonal_weights = np.zeros(max(x_timepoints, y_timepoints))

min_timepoints = min(x_timepoints, y_timepoints)
diagonal_weights[1] = 1.0
for i in range(1, min_timepoints):
diagonal_weights[i] = local_kernel[i - 1, i - 1]

cost_matrix[0, 0] = 1
cumulative_dp_diag[0, 0] = 1

for i in range(1, x_timepoints):
cost_matrix[i, 1] = cost_matrix[i - 1, 1] * local_kernel[i - 1, 2]
cumulative_dp_diag[i, 1] = cumulative_dp_diag[i - 1, 1] * diagonal_weights[i]

for j in range(1, y_timepoints):
cost_matrix[1, j] = cost_matrix[1, j - 1] * local_kernel[2, j - 1]
cumulative_dp_diag[1, j] = cumulative_dp_diag[1, j - 1] * diagonal_weights[j]

for i in range(1, x_timepoints):
for j in range(1, y_timepoints):
local_cost = local_kernel[i - 1, j - 1]
cost_matrix[i, j] = (
cost_matrix[i - 1, j]
+ cost_matrix[i, j - 1]
+ cost_matrix[i - 1, j - 1]
) * local_cost
if i == j:
cumulative_dp_diag[i, j] = (
cumulative_dp_diag[i - 1, j - 1] * local_cost
+ cumulative_dp_diag[i - 1, j] * diagonal_weights[i]
+ cumulative_dp_diag[i, j - 1] * diagonal_weights[j]
)
else:
cumulative_dp_diag[i, j] = (
cumulative_dp_diag[i - 1, j] * diagonal_weights[i]
+ cumulative_dp_diag[i, j - 1] * diagonal_weights[j]
)
cost_matrix = cost_matrix + cumulative_dp_diag
return cost_matrix[x_timepoints - 1, y_timepoints - 1]


def kdtw(x, y, sigma=1.0, epsilon=1e-3):
"""
Callable kernel function for KernelKMeans.

Parameters
----------
X: np.ndarray, of shape (n_timepoints, n_channels)
First time series sample.
y: np.ndarray, of shape (n_timepoints, n_channels)
Second time series sample.
sigma : float, default=1.0
Parameter controlling the width of the exponential local kernel. Smaller sigma
values lead to a sharper decay of similarity with increasing distance.
epsilon : float, default=1e-3
A small constant added for numerical stability to avoid zero values in the
local kernel matrix.

Returns
-------
similarity : float
A scalar value representing the computed KDTW similarity between the two time
series. Higher values indicate greater similarity.
"""
distance = squared_pairwise_distance(x, y)
local_kernel = (np.exp(-distance / sigma) + epsilon) / (3 * (1 + epsilon))
return _kdtw_lk(x, y, local_kernel)


def factory_kdtw_kernel(channels: int):
"""
Return a kdtw kernel callable function that flattened samples to (T, channels).

Parameters
----------
channels: int
Number of channels per timepoint.

Returns
-------
kdtw_kernel : callable
A callable kernel function that computes the KDTW similarity between two
time series samples. The function signature is the same as the kdtw
function.
"""

def kdtw_kernel(x, y, sigma=1.0, epsilon=1e-3):
if x.ndim == 1:
T = x.size // channels
x = x.reshape(T, channels)
if y.ndim == 1:
T = y.size // channels
y = y.reshape(T, channels)
return kdtw(x, y, sigma=sigma, epsilon=epsilon)

return kdtw_kernel


class TimeSeriesKernelKMeans(BaseClusterer):
Expand Down Expand Up @@ -141,6 +254,9 @@ def _fit(self, X, y=None):
if self.verbose is True:
verbose = 1

if self.kernel == "kdtw":
self.kernel = factory_kdtw_kernel(channels=X.shape[1])

self._tslearn_kernel_k_means = TsLearnKernelKMeans(
n_clusters=self.n_clusters,
kernel=self.kernel,
Expand Down
24 changes: 24 additions & 0 deletions aeon/clustering/tests/test_kernel_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@

expected_results = [0, 0, 0, 0, 0]

expected_labels_kdtw = [0, 0, 0, 1, 2]

expected_iters_kdtw = 2

expected_results_kdtw = [0, 2, 0, 0, 0]


@pytest.mark.skipif(
not _check_estimator_deps(TimeSeriesKernelKMeans, severity="none"),
Expand All @@ -37,3 +43,21 @@ def test_kernel_k_means():

for val in proba:
assert np.count_nonzero(val == 1.0) == 1

kernel_kmeans_kdtw = TimeSeriesKernelKMeans(
kernel="kdtw",
random_state=1,
n_clusters=3,
kernel_params={"sigma": 2.0, "epsilon": 1e-4},
)
kernel_kmeans_kdtw.fit(X_train[0:max_train])
kdtw_results = kernel_kmeans_kdtw.predict(X_test[0:max_train])
kdtw_proba = kernel_kmeans_kdtw.predict_proba(X_test[0:max_train])

assert np.array_equal(kdtw_results, expected_results_kdtw)
assert kernel_kmeans_kdtw.n_iter_ == expected_iters_kdtw
assert np.array_equal(kernel_kmeans_kdtw.labels_, expected_labels_kdtw)
assert kdtw_proba.shape == (max_train, 3)

for val in kdtw_proba:
assert np.count_nonzero(val == 1.0) == 1