Support mpi4py 4.x.x #1618

Merged (11 commits) on Sep 2, 2024
13 changes: 9 additions & 4 deletions heat/core/communication.py
@@ -258,7 +258,7 @@ def mpi_type_of(cls, dtype: torch.dtype) -> MPI.Datatype:
def mpi_type_and_elements_of(
cls,
obj: Union[DNDarray, torch.Tensor],
- counts: Tuple[int],
+ counts: Optional[Tuple[int]],
displs: Tuple[int],
is_contiguous: Optional[bool],
) -> Tuple[MPI.Datatype, Tuple[int, ...]]:
@@ -289,7 +289,7 @@ def mpi_type_and_elements_of(
if is_contiguous:
if counts is None:
return mpi_type, elements
- factor = np.prod(obj.shape[1:])
+ factor = np.prod(obj.shape[1:], dtype=np.int32)
return (
mpi_type,
(
@@ -326,14 +326,15 @@ def as_mpi_memory(cls, obj) -> MPI.memory:
obj : torch.Tensor
The tensor to be converted into a MPI memory view.
"""
+ # TODO: MPI.memory might be deprecated in future versions of mpi4py. The following code might need to be adapted to use MPI.buffer instead.
return MPI.memory.fromaddress(obj.data_ptr(), 0)

@classmethod
def as_buffer(
cls,
obj: torch.Tensor,
- counts: Tuple[int] = None,
- displs: Tuple[int] = None,
+ counts: Optional[Tuple[int]] = None,
+ displs: Optional[Tuple[int]] = None,
is_contiguous: Optional[bool] = None,
) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]:
"""
@@ -356,6 +357,10 @@ def as_buffer(
obj.unsqueeze_(-1)
squ = True

+ if counts is not None:
+ counts = tuple(int(c) for c in counts)
+ if displs is not None:
+ displs = tuple(int(d) for d in displs)
mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous)
mpi_mem = cls.as_mpi_memory(obj)
if squ:
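A note on the two changes above: np.prod returns a NumPy integer scalar (typically int64), and as_buffer now converts counts and displs to plain Python ints before handing them to mpi_type_and_elements_of. The sketch below only illustrates that conversion with made-up values; it is not part of the PR, and the assumption that mpi4py 4.x is stricter about NumPy scalars in buffer specs is drawn from the change itself.

    # Illustration with hypothetical values: NumPy integer scalars vs. the plain
    # Python ints that the updated as_buffer now forwards to the buffer spec.
    import numpy as np

    counts = (np.int64(4), np.int64(4))
    displs = (np.int64(0), np.int64(4))
    counts = tuple(int(c) for c in counts)     # -> (4, 4), plain Python ints
    displs = tuple(int(d) for d in displs)     # -> (0, 4)

    # factor computed as in mpi_type_and_elements_of, kept as a 32-bit value
    factor = np.prod((3, 5), dtype=np.int32)   # np.int32(15) instead of a 64-bit scalar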
2 changes: 1 addition & 1 deletion heat/core/manipulations.py
@@ -3478,7 +3478,7 @@ def vsplit(x: DNDarray, indices_or_sections: Iterable) -> List[DNDarray, ...]:
return split(x, indices_or_sections, 0)


- def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
+ def resplit(arr: DNDarray, axis: Optional[int] = None) -> DNDarray:
"""
Out-of-place redistribution of the content of the `DNDarray`. Allows to "unsplit" (i.e. gather) all values from all
nodes, as well as to define a new axis along which the array is split without changes to the values.
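For context, resplit performs the out-of-place redistribution described in its docstring. A minimal usage sketch, assuming a run with more than one MPI process; the array shapes and values are illustrative and not from the PR.

    # Hedged usage sketch for ht.resplit: gather a split array, then redistribute it.
    import heat as ht

    a = ht.zeros((4, 6), split=0)           # rows spread across processes
    gathered = ht.resplit(a, axis=None)     # axis=None gathers the full array on every process
    by_cols = ht.resplit(a, axis=1)         # redistribute along the column axis
    print(gathered.split, by_cols.split)    # None 1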
9 changes: 7 additions & 2 deletions heat/core/tests/test_io.py
@@ -6,6 +6,7 @@
import random
import shutil
import fnmatch
+ import unittest

import heat as ht
from .test_suites.basic_test import TestCase
@@ -148,6 +149,10 @@ def test_load_csv(self):
with self.assertRaises(TypeError):
ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0)

+ @unittest.skipIf(
+ len(TestCase.get_hostnames()) > 1 and not os.environ.get("TMPDIR"),
+ "Requires the environment variable 'TMPDIR' to point to a globally accessible path. Otherwise the test will be skipped on multi-node setups.",
+ )
def test_save_csv(self):
for rnd_type in [
(ht.random.randint, ht.types.int32),
@@ -160,11 +165,11 @@ def test_save_csv(self):
for headers in [None, ["# This", "# is a", "# test."]]:
for shape in [(1, 1), (10, 10), (20, 1), (1, 20), (25, 4), (4, 25)]:
if rnd_type[0] == ht.random.randint:
- data = rnd_type[0](
+ data: ht.DNDarray = rnd_type[0](
-1000, 1000, size=shape, dtype=rnd_type[1], split=split
)
else:
- data = rnd_type[0](
+ data: ht.DNDarray = rnd_type[0](
shape[0],
shape[1],
split=split,
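The new skip decorator above encodes the condition restated below. The restatement is only illustrative, and the reasoning about node-local temporary directories is an assumption based on the decorator's message rather than something stated elsewhere in the PR.

    # Restating the skip condition: run test_save_csv only when the job is
    # single-node, or when TMPDIR points to a path that all nodes can see.
    import os
    from heat.core.tests.test_suites.basic_test import TestCase

    multi_node = len(TestCase.get_hostnames()) > 1
    shared_tmp_configured = bool(os.environ.get("TMPDIR"))
    skip_test = multi_node and not shared_tmp_configured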
12 changes: 12 additions & 0 deletions heat/core/tests/test_suites/basic_test.py
@@ -1,4 +1,5 @@
import unittest
+ import platform
import os

from heat.core import dndarray, MPICommunication, MPI, types, factories
@@ -12,6 +13,7 @@
class TestCase(unittest.TestCase):
__comm = MPICommunication()
__device = None
+ _hostnames: list[str] = None

@property
def comm(self):
@@ -62,6 +64,16 @@ def get_rank(self):
def get_size(self):
return self.comm.size

+ @classmethod
+ def get_hostnames(cls):
+ if not cls._hostnames:
+ if platform.system() == "Windows":
+ host = platform.uname().node
+ else:
+ host = os.uname()[1]
+ cls._hostnames = set(cls.__comm.handle.allgather(host))
+ return cls._hostnames

def assert_array_equal(self, heat_array, expected_array):
"""
Check if the heat_array is equivalent to the expected_array. Therefore first the split heat_array is compared to
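The new get_hostnames helper gathers one hostname per rank and returns the de-duplicated set, so its size equals the number of distinct nodes in the run. A short usage sketch; the host names shown are hypothetical.

    from heat.core.tests.test_suites.basic_test import TestCase

    hosts = TestCase.get_hostnames()   # e.g. {"node-a", "node-b"} on a two-node job
    if len(hosts) < 2:
        print("single-node run, multi-node-only tests will be skipped:", hosts)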
15 changes: 11 additions & 4 deletions heat/optim/tests/test_dp_optimizer.py
@@ -7,15 +7,22 @@
from heat.core.tests.test_suites.basic_test import TestCase


- @unittest.skipIf(
- int(os.getenv("SLURM_NNODES", "1")) < 2 or torch.cuda.device_count() == 0,
- "only supported for GPUs and at least two nodes",
- )
class TestDASO(TestCase):

+ @unittest.skipUnless(
+ len(TestCase.get_hostnames()) >= 2
+ and torch.cuda.device_count() > 1
+ and TestCase.device == "cuda",
+ f"only supported for GPUs and at least two nodes, Nodes = {TestCase.get_hostnames()}, torch.cuda.device_count() = {torch.cuda.device_count()}, rank = {ht.MPI_WORLD.rank}",
+ )
def test_daso(self):
import heat.nn.functional as F
import heat.optim as optim

+ print(
+ f"rank = {ht.MPI_WORLD.rank}, host = {os.uname()[1]}, torch.cuda.device_count() = {torch.cuda.device_count()}, torch.cuda.current_device() = {torch.cuda.current_device()}, NNodes = {len(TestCase.get_hostnames())}"
+ )

class Model(ht.nn.Module):
def __init__(self):
super(Model, self).__init__()
2 changes: 1 addition & 1 deletion setup.py
@@ -33,7 +33,7 @@
"Topic :: Scientific/Engineering",
],
install_requires=[
"mpi4py>=3.0.0, <4.0.0",
"mpi4py>=3.0.0",
Collaborator:
Did you run it on 3.x.x?

Member Author:
Not sure what you mean by that. I installed mpi4py 4.0.0 manually.

Collaborator:
Have you checked whether the changes are backwards compatible with the older versions? If not, we have to raise the minimum dependency or add some more lines reflecting that.

Member Author:
You are right, I did not check for backwards compatibility. I think the only problem with that one might be the rename of buffer/memory, but I will run some more tests.

"numpy>=1.22.0, <2",
"torch>=2.0.0, <2.4.1",
"scipy>=1.10.0",
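Regarding the review discussion about backwards compatibility: the TODO in communication.py already flags the MPI.memory/MPI.buffer rename. The sketch below is one possible guard, not something this PR adds, and it assumes that mpi4py 4.x exposes MPI.buffer while older 3.x releases only provide MPI.memory.

    # Hedged compatibility sketch (not part of this PR): pick whichever buffer
    # type the installed mpi4py provides, so the code keeps working on 3.x.
    from mpi4py import MPI

    _MPIBuffer = getattr(MPI, "buffer", None) or getattr(MPI, "memory")

    def as_mpi_memory_compat(tensor):
        # fromaddress builds a zero-length view at the tensor's data pointer,
        # mirroring what as_mpi_memory does in communication.py.
        return _MPIBuffer.fromaddress(tensor.data_ptr(), 0)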