Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/469 kmeans rework #470

Merged
2 changes: 1 addition & 1 deletion examples/lasso/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import heat as ht
from matplotlib import pyplot as plt
from sklearn import datasets
import heat.core.regression.lasso as lasso
import heat.regression.lasso as lasso
import plotfkt

# read scikit diabetes data set
Expand Down
5 changes: 3 additions & 2 deletions heat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import core
from . import cluster
from . import regression
from . import spatial
from .core import *
from .core import __version__
from .core import cluster
from .core import regression
from . import utils
File renamed without changes.
37 changes: 17 additions & 20 deletions heat/core/cluster/kmeans.py → heat/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def cluster_centers_(self):
Coordinates of cluster centers. If the algorithm stops before fully converging (see tol and max_iter),
these will not be consistent with labels_.
"""
return self._cluster_centers.squeeze(axis=0).T
return self._cluster_centers

@property
def labels_(self):
Expand Down Expand Up @@ -113,7 +113,7 @@ def _initialize_cluster_centers(self, X):
# Samples will be equally distributed drawn from all involved processes
_, displ, _ = X.comm.counts_displs_shape(shape=X.shape, axis=0)
centroids = ht.empty(
(X.shape[1], self.n_clusters), split=None, device=X.device, comm=X.comm
(self.n_clusters, X.shape[1]), split=None, device=X.device, comm=X.comm
)
if (X.split is None) or (X.split == 0):
for i in range(self.n_clusters):
Expand All @@ -132,12 +132,12 @@ def _initialize_cluster_centers(self, X):
idx = sample - displ[proc]
xi = ht.array(X.lloc[idx, :], device=X.device, comm=X.comm)
xi.comm.Bcast(xi, root=proc)
centroids[:, i] = xi
centroids[i, :] = xi

else:
raise NotImplementedError("Not implemented for other splitting-axes")

self._cluster_centers = centroids.expand_dims(axis=0)
self._cluster_centers = centroids

# directly passed centroids
elif isinstance(self.init, ht.DNDarray):
Expand All @@ -147,14 +147,13 @@ def _initialize_cluster_centers(self, X):
)
if self.init.shape[0] != self.n_clusters or self.init.shape[1] != X.shape[1]:
raise ValueError("passed centroids do not match cluster count or data shape")
self._cluster_centers = self.init.resplit(None).T.expand_dims(axis=0)
self._cluster_centers = self.init.resplit(None)

# kmeans++, smart centroid guessing
elif self.init == "kmeans++":
if (X.split is None) or (X.split == 0):
X = X.expand_dims(axis=2)
centroids = ht.empty(
(1, X.shape[1], self.n_clusters), split=None, device=X.device, comm=X.comm
centroids = ht.zeros(
(self.n_clusters, X.shape[1]), split=None, device=X.device, comm=X.comm
)
sample = ht.random.randint(0, X.shape[0] - 1).item()
_, displ, _ = X.comm.counts_displs_shape(shape=X.shape, axis=0)
Expand All @@ -166,15 +165,13 @@ def _initialize_cluster_centers(self, X):
x0 = ht.zeros(X.shape[1], dtype=X.dtype, device=X.device, comm=X.comm)
if X.comm.rank == proc:
idx = sample - displ[proc]
x0 = ht.array(X.lloc[idx, :, 0], device=X.device, comm=X.comm)
x0 = ht.array(X.lloc[idx, :], device=X.device, comm=X.comm)
x0.comm.Bcast(x0, root=proc)
centroids[0, :, 0] = x0

centroids[0, :] = x0
for i in range(1, self.n_clusters):
distances = ((X - centroids[:, :, :i]) ** 2).sum(axis=1, keepdim=True)
D2 = distances.min(axis=2)
distances = ht.spatial.distances.cdist(X, centroids, quadratic_expansion=True)
D2 = distances.min(axis=1)
D2.resplit_(axis=None)
D2 = D2.squeeze()
prob = D2 / D2.sum()
x = ht.random.rand().item()
sample = 0
Expand All @@ -192,9 +189,10 @@ def _initialize_cluster_centers(self, X):
xi = ht.zeros(X.shape[1], dtype=X.dtype)
if X.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(X.lloc[idx, :, 0], device=X.device, comm=X.comm)
xi = ht.array(X.lloc[idx, :], device=X.device, comm=X.comm)
xi.comm.Bcast(xi, root=proc)
centroids[0, :, i] = xi
centroids[i, :] = xi

else:
raise NotImplementedError("Not implemented for other splitting-axes")

Expand All @@ -217,8 +215,8 @@ def _fit_to_cluster(self, X):
Training instances to cluster.
"""
# calculate the distance matrix and determine the closest centroid
distances = ((X - self._cluster_centers) ** 2).sum(axis=1, keepdim=True)
matching_centroids = distances.argmin(axis=2, keepdim=True)
distances = ht.spatial.distances.cdist(X, self._cluster_centers, quadratic_expansion=True)
matching_centroids = distances.argmin(axis=1, keepdim=True)

return matching_centroids

Expand All @@ -240,7 +238,6 @@ def fit(self, X):
self._n_iter = 0
matching_centroids = ht.zeros((X.shape[0]), split=X.split, device=X.device, comm=X.comm)

X = X.expand_dims(axis=2)
new_cluster_centers = self._cluster_centers.copy()

# iteratively fit the points to the centroids
Expand All @@ -262,7 +259,7 @@ def fit(self, X):
)

# compute the new centroids
new_cluster_centers[:, :, i : i + 1] = assigned_points / points_in_cluster
new_cluster_centers[i : i + 1, :] = assigned_points / points_in_cluster

# check whether centroid movement has converged
self._inertia = ((self._cluster_centers - new_cluster_centers) ** 2).sum()
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@ def test_fit_iris(self):
# check whether the results are correct
self.assertIsInstance(kmeans.cluster_centers_, ht.DNDarray)
self.assertEqual(kmeans.cluster_centers_.shape, (k, iris.shape[1]))

iris_split = ht.load("heat/datasets/data/iris.csv", sep=";", split=1)
kmeans = ht.cluster.KMeans(n_clusters=k)

with self.assertRaises(NotImplementedError):
kmeans.fit(iris_split)

kmeans = ht.cluster.KMeans(n_clusters=k, init="random_number")
with self.assertRaises(ValueError):
kmeans.fit(iris_split)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_lasso(self):
X = X / ht.sqrt((ht.mean(X ** 2, axis=0)))
m, n = X.shape
# HeAT lasso instance
estimator = ht.core.regression.lasso.HeatLasso(max_iter=100, tol=None)
estimator = ht.regression.lasso.HeatLasso(max_iter=100, tol=None)
# check whether the results are correct
self.assertEqual(estimator.lam, 0.1)
self.assertTrue(estimator.theta is None)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_lasso(self):
X = X / torch.sqrt((torch.mean(X ** 2, 0)))
m, n = X.shape

estimator = ht.core.regression.lasso.PytorchLasso(max_iter=100, tol=None)
estimator = ht.regression.lasso.PytorchLasso(max_iter=100, tol=None)
# check whether the results are correct
self.assertEqual(estimator.lam, 0.1)
self.assertTrue(estimator.theta is None)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_lasso(self):
X = X / np.sqrt((np.mean(X ** 2, axis=0, keepdims=True)))
m, n = X.shape

estimator = ht.core.regression.lasso.NumpyLasso(max_iter=100, tol=None)
estimator = ht.regression.lasso.NumpyLasso(max_iter=100, tol=None)
# check whether the results are correct
self.assertEqual(estimator.lam, 0.1)
self.assertTrue(estimator.theta is None)
Expand Down
1 change: 1 addition & 0 deletions heat/spatial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .distances import *
38 changes: 38 additions & 0 deletions heat/spatial/distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
import numpy as np

from .. import core

__all__ = ["cdist"]


def cdist(X, Y, quadratic_expansion=False):
# ToDo Case X==Y
result = core.factories.zeros(
(X.shape[0], Y.shape[0]),
dtype=core.types.float32,
split=X.split,
device=X.device,
comm=X.comm,
)

if X.split is not None:
if X.split != 0:
# ToDo: Find out if even possible
raise NotImplementedError("Splittings other than 0 or None currently not supported.")
if Y.split is not None:
# ToDo: This requires communication of calculated blocks, will be implemented with Similarity Matrix Calculation
raise NotImplementedError("Currently not supported")

if quadratic_expansion:
x_norm = (X._DNDarray__array ** 2).sum(1).view(-1, 1)
y_t = torch.transpose(Y._DNDarray__array, 0, 1)
y_norm = (Y._DNDarray__array ** 2).sum(1).view(1, -1)

dist = x_norm + y_norm - 2.0 * torch.mm(X._DNDarray__array, y_t)
result._DNDarray__array = torch.sqrt(torch.clamp(dist, 0.0, np.inf))

else:
result._DNDarray__array = torch.cdist(X._DNDarray__array, Y._DNDarray__array)

return result
Empty file added heat/spatial/tests/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions heat/spatial/tests/test_distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import unittest
import os

import torch

import heat as ht
import numpy as np

if os.environ.get("DEVICE") == "gpu" and torch.cuda.is_available():
ht.use_device("gpu")
torch.cuda.set_device(torch.device(ht.get_device().torch_device))
else:
ht.use_device("cpu")
device = ht.get_device().torch_device
ht_device = None
if os.environ.get("DEVICE") == "lgpu" and torch.cuda.is_available():
device = ht.gpu.torch_device
ht_device = ht.gpu
torch.cuda.set_device(device)


class TestDistances(unittest.TestCase):
def test_cdist(self):
Cdebus marked this conversation as resolved.
Show resolved Hide resolved
split = None
X = ht.ones((4, 4), dtype=ht.float32, split=split, device=ht_device)
Y = ht.zeros((4, 4), dtype=ht.float32, split=None, device=ht_device)

res = ht.ones((4, 4), dtype=ht.float32, split=split) * 2

d = ht.spatial.cdist(X, Y, quadratic_expansion=False)
self.assertTrue(ht.equal(d, res))
self.assertEqual(d.split, split)

d = ht.spatial.cdist(X, Y, quadratic_expansion=True)
self.assertTrue(ht.equal(d, res))
self.assertEqual(d.split, split)

split = 0
X = ht.ones((4, 4), dtype=ht.float32, split=split)
res = ht.ones((4, 4), dtype=ht.float32, split=split) * 2

d = ht.spatial.cdist(X, Y, quadratic_expansion=False)
self.assertTrue(ht.equal(d, res))
self.assertEqual(d.split, split)

d = ht.spatial.cdist(X, Y, quadratic_expansion=True)
self.assertTrue(ht.equal(d, res))
self.assertEqual(d.split, split)

Y = ht.zeros((4, 4), dtype=ht.float32, split=split)

with self.assertRaises(NotImplementedError):
d = ht.spatial.cdist(X, Y, quadratic_expansion=False)
with self.assertRaises(NotImplementedError):
d = ht.spatial.cdist(X, Y, quadratic_expansion=True)

X = ht.ones((4, 4), dtype=ht.float32, split=1)
Y = ht.zeros((4, 4), dtype=ht.float32, split=None)

with self.assertRaises(NotImplementedError):
d = ht.spatial.cdist(X, Y, quadratic_expansion=False)
with self.assertRaises(NotImplementedError):
d = ht.spatial.cdist(X, Y, quadratic_expansion=True)
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
packages=[
"heat",
"heat.core",
"heat.core.cluster",
"heat.core.regression",
"heat.core.regression.lasso",
"heat.cluster",
"heat.regression",
"heat.regression.lasso",
"heat.utils",
],
data_files=["README.md", "LICENSE"],
Expand Down