Bugs/1592: Resolved two bugs in BatchParallel Clustering (#1593)
```diff
@@ -19,20 +19,24 @@
 """


-def _initialize_plus_plus(X, n_clusters, p, random_state=None):
+def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24 - 1):
     """
     Auxiliary function: single-process k-means++/k-medians++ initialization in pytorch
     p is the norm used for computing distances
     """
     if random_state is not None:
         torch.manual_seed(random_state)
-    idxs = torch.zeros(n_clusters, dtype=torch.long, device=X.device)
-    idxs[0] = torch.randint(0, X.shape[0], (1,))
-    for i in range(1, n_clusters):
-        dist = torch.cdist(X, X[idxs[:i]], p=p)
-        dist = torch.min(dist, dim=1)[0]
-        idxs[i] = torch.multinomial(dist, 1)
-    return X[idxs]
+    if X.shape[0] <= max_samples:  # torch's multinomial is limited to 2^24 categories
+        idxs = torch.zeros(n_clusters, dtype=torch.long, device=X.device)
+        idxs[0] = torch.randint(0, X.shape[0], (1,))
+        for i in range(1, n_clusters):
+            dist = torch.cdist(X, X[idxs[:i]], p=p)
+            dist = torch.min(dist, dim=1)[0]
+            idxs[i] = torch.multinomial(dist, 1)
+        return X[idxs]
+    else:  # if X is too large for the 2^24-bound, use a random subset of X
+        idxs = torch.randint(0, X.shape[0], (max_samples,))
+        return _initialize_plus_plus(X[idxs], n_clusters, p, random_state)


 def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None):
```
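For context, here is the patched initializer as a self-contained, runnable sketch. The function body is taken from the diff above; the synthetic data and the artificially small `max_samples` in the usage lines are illustrative only and not part of the PR:

```python
import torch


def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24 - 1):
    """k-means++/k-medians++ seeding; falls back to a random subset of X
    when X has more rows than torch.multinomial can handle."""
    if random_state is not None:
        torch.manual_seed(random_state)
    if X.shape[0] <= max_samples:  # torch's multinomial is limited to 2^24 categories
        idxs = torch.zeros(n_clusters, dtype=torch.long, device=X.device)
        idxs[0] = torch.randint(0, X.shape[0], (1,))
        for i in range(1, n_clusters):
            # distance of every sample to its nearest already-chosen center
            dist = torch.cdist(X, X[idxs[:i]], p=p)
            dist = torch.min(dist, dim=1)[0]
            # draw the next center with probability proportional to that distance
            idxs[i] = torch.multinomial(dist, 1)
        return X[idxs]
    else:
        # subsample so the recursive call stays below the multinomial bound
        idxs = torch.randint(0, X.shape[0], (max_samples,))
        return _initialize_plus_plus(X[idxs], n_clusters, p, random_state)


X = torch.randn(1000, 3)
centers = _initialize_plus_plus(X, n_clusters=4, p=2, random_state=0)
print(centers.shape)  # torch.Size([4, 3])

# exercising the fallback branch with an artificially small cap
centers = _initialize_plus_plus(X, n_clusters=4, p=2, random_state=0, max_samples=100)
print(centers.shape)  # torch.Size([4, 3])
```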
```diff
@@ -289,7 +293,7 @@ def predict(self, x: DNDarray):

         local_labels = _parallel_batched_kmex_predict(
             x.larray, self._cluster_centers.larray, self._p
-        )
+        ).to(torch.int32)

         labels = DNDarray(
             local_labels,
             gshape=(x.shape[0], 1),
```

Review discussion on the `.to(torch.int32)` cast:

Reviewer: Why not do it the other way and set the heat array to the proper output type? I get the argument that it is an unlikely number of clusters, but it could theoretically happen.

Author: I also thought about this, and my arguments for the chosen solution were:
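A minimal illustration of the dtype mismatch the cast resolves, assuming the predictor derives labels via an argmin over distances (the stub below is a stand-in for `_parallel_batched_kmex_predict`, not the library function):

```python
import torch


def _predict_stub(x_local, centers, p):
    """Stand-in predictor: label each sample with its nearest center."""
    dist = torch.cdist(x_local, centers, p=p)
    return torch.argmin(dist, dim=1, keepdim=True)  # shape (n, 1)


x_local = torch.randn(8, 2)
centers = torch.randn(3, 2)

labels = _predict_stub(x_local, centers, p=2)
print(labels.dtype)                  # torch.int64: argmin always returns int64
print(labels.to(torch.int32).dtype)  # torch.int32: matches the DNDarray's dtype
```

Since `torch.int32` can represent cluster indices up to 2**31 - 1, the cast can only overflow for more than roughly 2.1 billion clusters; that is the sense in which the reviewer's concern is theoretical.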
Review discussion on the `max_samples` default:

Reviewer: Some unsuspecting user could try to change this value to something higher and then hit the limit in torch. Should we hard-code it?

Author: Actually, this is already effectively hard-coded, since it is an auxiliary function that is not exposed to users directly. The reason for introducing `max_samples` as a parameter was to keep some flexibility for adapting it in the future.

Author: I have added a comment in the function's description.
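An end-to-end usage sketch at the heat level; the class name `BatchParallelKMeans` and its argument names are assumptions based on the module under review, not confirmed by the diff:

```python
import heat as ht

# assumed API: ht.cluster.BatchParallelKMeans with a scikit-learn style
# fit/predict interface; adjust names if the actual module differs
X = ht.random.randn(10000, 3, split=0)  # data distributed along axis 0
model = ht.cluster.BatchParallelKMeans(n_clusters=4, random_state=0)
model.fit(X)
labels = model.predict(X)               # after this PR: int32 labels
print(labels.shape, labels.dtype)       # gshape (10000, 1)
```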