Skip to content

Commit c1c830b

Browse files
authored
[ENH] Adds kdtw kernel support for kernelkmeans (aeon-toolkit#2645)
* Adds kdtw kernel support for kernelkmeans * Code refactor * Adds tests for kdtw clustering * minor changes * minor changes
1 parent ee0b6bf commit c1c830b

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

aeon/clustering/_kernel_k_means.py

+116
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,122 @@
33
from typing import Optional, Union
44

55
import numpy as np
6+
from numba import njit
67
from numpy.random import RandomState
78

89
from aeon.clustering.base import BaseClusterer
10+
from aeon.distances.pointwise._squared import squared_pairwise_distance
11+
12+
13+
@njit(cache=True, fastmath=True)
14+
def _kdtw_lk(x, y, local_kernel):
15+
channels = np.shape(x)[1]
16+
padding_vector = np.zeros((1, channels))
17+
18+
x = np.concatenate((padding_vector, x), axis=0)
19+
y = np.concatenate((padding_vector, y), axis=0)
20+
21+
x_timepoints, _ = np.shape(x)
22+
y_timepoints, _ = np.shape(y)
23+
24+
cost_matrix = np.zeros((x_timepoints, y_timepoints))
25+
cumulative_dp_diag = np.zeros((x_timepoints, y_timepoints))
26+
diagonal_weights = np.zeros(max(x_timepoints, y_timepoints))
27+
28+
min_timepoints = min(x_timepoints, y_timepoints)
29+
diagonal_weights[1] = 1.0
30+
for i in range(1, min_timepoints):
31+
diagonal_weights[i] = local_kernel[i - 1, i - 1]
32+
33+
cost_matrix[0, 0] = 1
34+
cumulative_dp_diag[0, 0] = 1
35+
36+
for i in range(1, x_timepoints):
37+
cost_matrix[i, 1] = cost_matrix[i - 1, 1] * local_kernel[i - 1, 2]
38+
cumulative_dp_diag[i, 1] = cumulative_dp_diag[i - 1, 1] * diagonal_weights[i]
39+
40+
for j in range(1, y_timepoints):
41+
cost_matrix[1, j] = cost_matrix[1, j - 1] * local_kernel[2, j - 1]
42+
cumulative_dp_diag[1, j] = cumulative_dp_diag[1, j - 1] * diagonal_weights[j]
43+
44+
for i in range(1, x_timepoints):
45+
for j in range(1, y_timepoints):
46+
local_cost = local_kernel[i - 1, j - 1]
47+
cost_matrix[i, j] = (
48+
cost_matrix[i - 1, j]
49+
+ cost_matrix[i, j - 1]
50+
+ cost_matrix[i - 1, j - 1]
51+
) * local_cost
52+
if i == j:
53+
cumulative_dp_diag[i, j] = (
54+
cumulative_dp_diag[i - 1, j - 1] * local_cost
55+
+ cumulative_dp_diag[i - 1, j] * diagonal_weights[i]
56+
+ cumulative_dp_diag[i, j - 1] * diagonal_weights[j]
57+
)
58+
else:
59+
cumulative_dp_diag[i, j] = (
60+
cumulative_dp_diag[i - 1, j] * diagonal_weights[i]
61+
+ cumulative_dp_diag[i, j - 1] * diagonal_weights[j]
62+
)
63+
cost_matrix = cost_matrix + cumulative_dp_diag
64+
return cost_matrix[x_timepoints - 1, y_timepoints - 1]
65+
66+
67+
def kdtw(x, y, sigma=1.0, epsilon=1e-3):
68+
"""
69+
Callable kernel function for KernelKMeans.
70+
71+
Parameters
72+
----------
73+
X: np.ndarray, of shape (n_timepoints, n_channels)
74+
First time series sample.
75+
y: np.ndarray, of shape (n_timepoints, n_channels)
76+
Second time series sample.
77+
sigma : float, default=1.0
78+
Parameter controlling the width of the exponential local kernel. Smaller sigma
79+
values lead to a sharper decay of similarity with increasing distance.
80+
epsilon : float, default=1e-3
81+
A small constant added for numerical stability to avoid zero values in the
82+
local kernel matrix.
83+
84+
Returns
85+
-------
86+
similarity : float
87+
A scalar value representing the computed KDTW similarity between the two time
88+
series. Higher values indicate greater similarity.
89+
"""
90+
distance = squared_pairwise_distance(x, y)
91+
local_kernel = (np.exp(-distance / sigma) + epsilon) / (3 * (1 + epsilon))
92+
return _kdtw_lk(x, y, local_kernel)
93+
94+
95+
def factory_kdtw_kernel(channels: int):
96+
"""
97+
Return a kdtw kernel callable function that flattened samples to (T, channels).
98+
99+
Parameters
100+
----------
101+
channels: int
102+
Number of channels per timepoint.
103+
104+
Returns
105+
-------
106+
kdtw_kernel : callable
107+
A callable kernel function that computes the KDTW similarity between two
108+
time series samples. The function signature is the same as the kdtw
109+
function.
110+
"""
111+
112+
def kdtw_kernel(x, y, sigma=1.0, epsilon=1e-3):
113+
if x.ndim == 1:
114+
T = x.size // channels
115+
x = x.reshape(T, channels)
116+
if y.ndim == 1:
117+
T = y.size // channels
118+
y = y.reshape(T, channels)
119+
return kdtw(x, y, sigma=sigma, epsilon=epsilon)
120+
121+
return kdtw_kernel
9122

10123

11124
class TimeSeriesKernelKMeans(BaseClusterer):
@@ -141,6 +254,9 @@ def _fit(self, X, y=None):
141254
if self.verbose is True:
142255
verbose = 1
143256

257+
if self.kernel == "kdtw":
258+
self.kernel = factory_kdtw_kernel(channels=X.shape[1])
259+
144260
self._tslearn_kernel_k_means = TsLearnKernelKMeans(
145261
n_clusters=self.n_clusters,
146262
kernel=self.kernel,

aeon/clustering/tests/test_kernel_k_means.py

+24
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313

1414
expected_results = [0, 0, 0, 0, 0]
1515

16+
expected_labels_kdtw = [0, 0, 0, 1, 2]
17+
18+
expected_iters_kdtw = 2
19+
20+
expected_results_kdtw = [0, 2, 0, 0, 0]
21+
1622

1723
@pytest.mark.skipif(
1824
not _check_estimator_deps(TimeSeriesKernelKMeans, severity="none"),
@@ -37,3 +43,21 @@ def test_kernel_k_means():
3743

3844
for val in proba:
3945
assert np.count_nonzero(val == 1.0) == 1
46+
47+
kernel_kmeans_kdtw = TimeSeriesKernelKMeans(
48+
kernel="kdtw",
49+
random_state=1,
50+
n_clusters=3,
51+
kernel_params={"sigma": 2.0, "epsilon": 1e-4},
52+
)
53+
kernel_kmeans_kdtw.fit(X_train[0:max_train])
54+
kdtw_results = kernel_kmeans_kdtw.predict(X_test[0:max_train])
55+
kdtw_proba = kernel_kmeans_kdtw.predict_proba(X_test[0:max_train])
56+
57+
assert np.array_equal(kdtw_results, expected_results_kdtw)
58+
assert kernel_kmeans_kdtw.n_iter_ == expected_iters_kdtw
59+
assert np.array_equal(kernel_kmeans_kdtw.labels_, expected_labels_kdtw)
60+
assert kdtw_proba.shape == (max_train, 3)
61+
62+
for val in kdtw_proba:
63+
assert np.count_nonzero(val == 1.0) == 1

0 commit comments

Comments
 (0)