Merge pull request #1618 from helmholtz-analytics/fix/mpi4py-4-support
Support mpi4py 4.x.x
JuanPedroGHM authored Sep 2, 2024
2 parents cbdf49a + b005901 commit affdebf
Showing 6 changed files with 41 additions and 12 deletions.
heat/core/communication.py (9 additions & 4 deletions)
@@ -258,7 +258,7 @@ def mpi_type_of(cls, dtype: torch.dtype) -> MPI.Datatype:
def mpi_type_and_elements_of(
cls,
obj: Union[DNDarray, torch.Tensor],
- counts: Tuple[int],
+ counts: Optional[Tuple[int]],
displs: Tuple[int],
is_contiguous: Optional[bool],
) -> Tuple[MPI.Datatype, Tuple[int, ...]]:
@@ -289,7 +289,7 @@ def mpi_type_and_elements_of(
if is_contiguous:
if counts is None:
return mpi_type, elements
- factor = np.prod(obj.shape[1:])
+ factor = np.prod(obj.shape[1:], dtype=np.int32)
return (
mpi_type,
(
@@ -326,14 +326,15 @@ def as_mpi_memory(cls, obj) -> MPI.memory:
obj : torch.Tensor
The tensor to be converted into a MPI memory view.
"""
+ # TODO: MPI.memory might be deprecated in future versions of mpi4py. The following code might need to be adapted to use MPI.buffer instead.
return MPI.memory.fromaddress(obj.data_ptr(), 0)

@classmethod
def as_buffer(
cls,
obj: torch.Tensor,
- counts: Tuple[int] = None,
- displs: Tuple[int] = None,
+ counts: Optional[Tuple[int]] = None,
+ displs: Optional[Tuple[int]] = None,
is_contiguous: Optional[bool] = None,
) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]:
"""
@@ -356,6 +357,10 @@ def as_buffer(
obj.unsqueeze_(-1)
squ = True

+ if counts is not None:
+     counts = tuple(int(c) for c in counts)
+ if displs is not None:
+     displs = tuple(int(d) for d in displs)
mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous)
mpi_mem = cls.as_mpi_memory(obj)
if squ:
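
The TODO added above, together with the new integer casts in as_buffer, points at the two mpi4py 4.x concerns in this file: the MPI.memory type (mpi4py 4.x introduces MPI.buffer as the preferred name) and buffer counts arriving as NumPy integer scalars. The following is a minimal standalone sketch of both ideas, not code from this commit; the helper names are made up for illustration.

    # Compatibility sketch (assumption): mpi4py 4.x exposes MPI.buffer, older
    # releases only MPI.memory; both provide fromaddress().
    import torch
    from mpi4py import MPI

    _MPIBuffer = getattr(MPI, "buffer", None) or MPI.memory

    def tensor_as_mpi_buffer(tensor: torch.Tensor):
        # Wrap the tensor's storage at its data pointer, as as_mpi_memory does above.
        return _MPIBuffer.fromaddress(tensor.data_ptr(), 0)

    def as_python_ints(values):
        # Mirror the casts added to as_buffer: hand MPI plain Python ints
        # instead of NumPy integer scalars.
        return None if values is None else tuple(int(v) for v in values)
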
heat/core/manipulations.py (1 addition & 1 deletion)
@@ -3478,7 +3478,7 @@ def vsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray, ...]:
return split(x, indices_or_sections, 0)


- def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
+ def resplit(arr: DNDarray, axis: Optional[int] = None) -> DNDarray:
"""
Out-of-place redistribution of the content of the `DNDarray`. Allows to "unsplit" (i.e. gather) all values from all
nodes, as well as to define a new axis along which the array is split without changes to the values.
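
The resplit docstring above describes both redistribution along a new axis and the axis=None "unsplit" (gather) case. A short usage sketch, not part of this diff:

    import heat as ht

    a = ht.zeros((4, 4), split=0)      # distributed along axis 0
    b = ht.resplit(a, axis=1)          # redistribute along axis 1
    c = ht.resplit(a, axis=None)       # gather: every process holds the full array
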
heat/core/tests/test_io.py (7 additions & 2 deletions)
@@ -6,6 +6,7 @@
import random
import shutil
import fnmatch
+ import unittest

import heat as ht
from .test_suites.basic_test import TestCase
@@ -148,6 +149,10 @@ def test_load_csv(self):
with self.assertRaises(TypeError):
ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0)

+ @unittest.skipIf(
+     len(TestCase.get_hostnames()) > 1 and not os.environ.get("TMPDIR"),
+     "Requires the environment variable 'TMPDIR' to point to a globally accessible path; otherwise the test will be skipped on multi-node setups.",
+ )
def test_save_csv(self):
for rnd_type in [
(ht.random.randint, ht.types.int32),
@@ -160,11 +165,11 @@ def test_save_csv(self):
for headers in [None, ["# This", "# is a", "# test."]]:
for shape in [(1, 1), (10, 10), (20, 1), (1, 20), (25, 4), (4, 25)]:
if rnd_type[0] == ht.random.randint:
- data = rnd_type[0](
+ data: ht.DNDarray = rnd_type[0](
-1000, 1000, size=shape, dtype=rnd_type[1], split=split
)
else:
- data = rnd_type[0](
+ data: ht.DNDarray = rnd_type[0](
shape[0],
shape[1],
split=split,
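
The new skip guard reflects an assumption about where test_save_csv writes: on multi-node runs every rank must resolve the output path to the same shared directory. A hypothetical illustration of that assumption (not from the commit):

    import os
    import tempfile

    # TMPDIR should point at storage visible from all nodes (e.g. a parallel file
    # system); otherwise each host writes to its own local temporary directory.
    tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
    csv_path = os.path.join(tmp_dir, "ht_test.csv")
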
heat/core/tests/test_suites/basic_test.py (12 additions & 0 deletions)
@@ -1,4 +1,5 @@
import unittest
+ import platform
import os

from heat.core import dndarray, MPICommunication, MPI, types, factories
@@ -12,6 +13,7 @@
class TestCase(unittest.TestCase):
__comm = MPICommunication()
__device = None
+ _hostnames: list[str] = None

@property
def comm(self):
@@ -62,6 +64,16 @@ def get_rank(self):
def get_size(self):
return self.comm.size

+ @classmethod
+ def get_hostnames(cls):
+     if not cls._hostnames:
+         if platform.system() == "Windows":
+             host = platform.uname().node
+         else:
+             host = os.uname()[1]
+         cls._hostnames = set(cls.__comm.handle.allgather(host))
+     return cls._hostnames

def assert_array_equal(self, heat_array, expected_array):
"""
Check if the heat_array is equivalent to the expected_array. Therefore first the split heat_array is compared to
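
get_hostnames caches the set of distinct hosts participating in the run; the underlying pattern is an allgather of each rank's hostname over the communicator. A standalone sketch of that pattern using mpi4py directly (an illustration, not the helper itself):

    import platform
    from mpi4py import MPI

    host = platform.uname().node                 # portable across Linux and Windows
    hosts = set(MPI.COMM_WORLD.allgather(host))  # every rank learns all hostnames
    multi_node = len(hosts) > 1
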
heat/optim/tests/test_dp_optimizer.py (11 additions & 4 deletions)
@@ -7,15 +7,22 @@
from heat.core.tests.test_suites.basic_test import TestCase


- @unittest.skipIf(
-     int(os.getenv("SLURM_NNODES", "1")) < 2 or torch.cuda.device_count() == 0,
-     "only supported for GPUs and at least two nodes",
- )
class TestDASO(TestCase):

+ @unittest.skipUnless(
+     len(TestCase.get_hostnames()) >= 2
+     and torch.cuda.device_count() > 1
+     and TestCase.device == "cuda",
+     f"only supported for GPUs and at least two nodes, Nodes = {TestCase.get_hostnames()}, torch.cuda.device_count() = {torch.cuda.device_count()}, rank = {ht.MPI_WORLD.rank}",
+ )
def test_daso(self):
import heat.nn.functional as F
import heat.optim as optim

+ print(
+     f"rank = {ht.MPI_WORLD.rank}, host = {os.uname()[1]}, torch.cuda.device_count() = {torch.cuda.device_count()}, torch.cuda.current_device() = {torch.cuda.current_device()}, NNodes = {len(TestCase.get_hostnames())}"
+ )

class Model(ht.nn.Module):
def __init__(self):
super(Model, self).__init__()
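
The skip condition moves from the class to the test method and is now derived from the actual host list and GPU count instead of SLURM_NNODES. A generic sketch of that method-level pattern on a hypothetical test case (not the DASO suite itself):

    import unittest

    import torch

    class MultiGPUExample(unittest.TestCase):
        @unittest.skipUnless(torch.cuda.device_count() > 1, "needs more than one GPU")
        def test_multi_gpu_path(self):
            self.assertTrue(torch.cuda.is_available())
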
setup.py (1 addition & 1 deletion)
@@ -33,7 +33,7 @@
"Topic :: Scientific/Engineering",
],
install_requires=[
"mpi4py>=3.0.0, <4.0.0",
"mpi4py>=3.0.0",
"numpy>=1.22.0, <2",
"torch>=2.0.0, <2.4.1",
"scipy>=1.10.0",
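
With the upper bound dropped, an environment may provide either mpi4py 3.x or 4.x. A quick runtime check of which version and MPI backend are active (standard mpi4py attributes):

    import mpi4py
    from mpi4py import MPI

    print(mpi4py.__version__)            # e.g. '3.1.6' or '4.0.0'
    print(MPI.Get_library_version())     # underlying MPI implementation and version
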
