Skip to content

Commit

Permalink
Merge pull request #1262 from OceanParcels/support_installation_mpi_w…
Browse files Browse the repository at this point in the history
…ithout_kmeans

Support for running on systems with MPI but without sklearn
  • Loading branch information
erikvansebille authored Oct 20, 2022
2 parents ecda521 + d5babbd commit 56e988a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions parcels/collection/collectionsoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
try:
from sklearn.cluster import KMeans
except:
raise EnvironmentError('sklearn needs to be available if MPI is installed. '
'See http://oceanparcels.org/#parallel_install for more information')
KMeans = None


def _convert_to_flat_array(var):
Expand Down Expand Up @@ -82,8 +81,13 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, p
if (self._pu_indicators is None) or (len(self._pu_indicators) != len(lon)):
if mpi_rank == 0:
coords = np.vstack((lon, lat)).transpose()
kmeans = KMeans(n_clusters=mpi_size, random_state=0).fit(coords)
self._pu_indicators = kmeans.labels_
if KMeans:
kmeans = KMeans(n_clusters=mpi_size, random_state=0).fit(coords)
self._pu_indicators = kmeans.labels_
else: # assigning random labels if no KMeans (see https://github.com/OceanParcels/parcels/issues/1261)
logger.warning_once('sklearn needs to be available if MPI is installed. '
'See http://oceanparcels.org/#parallel_install for more information')
self._pu_indicators = np.randint(0, mpi_size, size=len(lon))
else:
self._pu_indicators = None
self._pu_indicators = mpi_comm.bcast(self._pu_indicators, root=0)
Expand Down

0 comments on commit 56e988a

Please sign in to comment.