
Bugs/1592 Resolved two bugs in BatchParallel Clustering #1593

Merged
22 changes: 13 additions & 9 deletions heat/cluster/batchparallelclustering.py
@@ -19,20 +19,24 @@
"""


-def _initialize_plus_plus(X, n_clusters, p, random_state=None):
+def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24 - 1):
Member:
Some unsuspecting user could try to change this value to something higher and then run into the limit in torch. Should we hard-code it?

Collaborator Author (@mrfh92, Aug 12, 2024):
Actually, this is already effectively hard-coded, since this is an auxiliary function that is not exposed to the user directly.
The reason for introducing max_samples as a parameter was to keep some flexibility for adapting it in the future.
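
For reference, a minimal sketch (not part of the PR) of the torch limit that motivates the default of 2**24 - 1; the exact error message and boundary may vary between PyTorch versions:

```python
import torch

# torch.multinomial rejects inputs with more than 2**24 categories,
# which is why max_samples is kept just below that bound.
weights_ok = torch.ones(2**24 - 1)
torch.multinomial(weights_ok, 1)          # fine

weights_too_big = torch.ones(2**24 + 1)   # roughly 64 MiB of float32 weights
try:
    torch.multinomial(weights_too_big, 1)
except RuntimeError as err:
    print(err)                            # e.g. "number of categories cannot exceed 2^24"
```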

Collaborator Author (@mrfh92):
I have added a comment to the function's description.

"""
Auxiliary function: single-process k-means++/k-medians++ initialization in pytorch
p is the norm used for computing distances
"""
if random_state is not None:
torch.manual_seed(random_state)
-    idxs = torch.zeros(n_clusters, dtype=torch.long, device=X.device)
-    idxs[0] = torch.randint(0, X.shape[0], (1,))
-    for i in range(1, n_clusters):
-        dist = torch.cdist(X, X[idxs[:i]], p=p)
-        dist = torch.min(dist, dim=1)[0]
-        idxs[i] = torch.multinomial(dist, 1)
-    return X[idxs]
+    if X.shape[0] <= max_samples:  # torch's multinomial is limited to 2^24 categories
+        idxs = torch.zeros(n_clusters, dtype=torch.long, device=X.device)
+        idxs[0] = torch.randint(0, X.shape[0], (1,))
+        for i in range(1, n_clusters):
+            dist = torch.cdist(X, X[idxs[:i]], p=p)
+            dist = torch.min(dist, dim=1)[0]
+            idxs[i] = torch.multinomial(dist, 1)
+        return X[idxs]
+    else:  # if X is too large for the 2^24 bound, use a random subset of X
+        idxs = torch.randint(0, X.shape[0], (max_samples,))
+        return _initialize_plus_plus(X[idxs], n_clusters, p, random_state)


def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None):
@@ -289,7 +293,7 @@ def predict(self, x: DNDarray):

         local_labels = _parallel_batched_kmex_predict(
             x.larray, self._cluster_centers.larray, self._p
-        )
+        ).to(torch.int32)
Member:
Why not do it the other way around and set the Heat array to the proper output type? I get the argument that it is an unlikely number of clusters, but it could theoretically happen.

Collaborator Author (@mrfh92):
I also thought about this; my arguments for the chosen solution were:

  • int32 saves 50% of the memory compared to int64 during further processing of the clustering result (see the rough sketch below);
  • in theory, more than int32-many cluster centers are conceivable, but in practice this is completely out of scope: the runtime of our clustering algorithms depends heavily on the number of cluster centers, and the usual reason for clustering is to gain insight into the structure of the data by grouping it into a comparably small number of clusters.
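
As a rough illustration of the memory argument (the numbers below are hypothetical, not from the PR):

```python
import torch

n_samples = 10_000_000  # hypothetical dataset size
labels_int64 = torch.zeros(n_samples, dtype=torch.int64)
labels_int32 = torch.zeros(n_samples, dtype=torch.int32)

# 8 bytes vs. 4 bytes per label, i.e. int32 halves the label memory footprint
print(labels_int64.element_size() * labels_int64.numel() / 2**20)  # ~76.3 MiB
print(labels_int32.element_size() * labels_int32.numel() / 2**20)  # ~38.1 MiB
```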

         labels = DNDarray(
             local_labels,
             gshape=(x.shape[0], 1),
6 changes: 5 additions & 1 deletion heat/cluster/tests/test_batchparallelclustering.py
@@ -7,7 +7,7 @@
from mpi4py import MPI

from ...core.tests.test_suites.basic_test import TestCase
-from ..batchparallelclustering import _kmex, _BatchParallelKCluster
+from ..batchparallelclustering import _kmex, _initialize_plus_plus, _BatchParallelKCluster

# test BatchParallelKCluster base class and auxiliary functions

@@ -32,6 +32,10 @@ def test_kmex(self):
         init = torch.rand(2, 3)
         _kmex(X, 2, 2, init, max_iter, tol)

+    def test_initialize_plus_plus(self):
+        X = torch.rand(100, 3)
+        _initialize_plus_plus(X, 3, 2, random_state=None, max_samples=50)

     def test_BatchParallelKClustering(self):
         with self.assertRaises(TypeError):
             _BatchParallelKCluster(2, 10, "++", 100, 1e-2, random_state=3.14, n_procs_to_merge=None)
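
A possible follow-up for the new test, sketched here under the assumption that the helper returns a tensor of shape (n_clusters, n_features) and that _initialize_plus_plus is imported as in the test module, would be to also assert on the returned centroids:

```python
import torch

X = torch.rand(100, 3)
centers = _initialize_plus_plus(X, 3, 2, random_state=0, max_samples=50)
# 3 initial centers drawn from the 3-dimensional samples, also when the
# max_samples subsampling branch is taken (100 > 50)
assert centers.shape == (3, 3)
```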