
Fix: ht.array constructor respects implicit torch device when copy is set to false #1363

Merged
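For context, here is a minimal sketch of the behavior this PR establishes. It is illustrative only (not code from the PR) and assumes a single MPI process on the CPU:

```python
import torch
import heat as ht

# A torch tensor living on the implicit (default) torch device.
t = torch.arange(6, dtype=torch.float32)

# With copy=False, ht.array must not silently copy or transfer the data:
# the resulting DNDarray wraps the existing buffer.
a = ht.array(t, copy=False)
assert a.larray.data_ptr() == t.data_ptr()  # storage is reused, not copied

# With copy=None, heat reuses the buffer when possible but may still copy
# when necessary, e.g. for a dtype conversion.
b = ht.array(t, copy=None, dtype=ht.float64)
```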
Commits (18)
b1203e7
fix: ht.array constructor respects implicit torch devices
JuanPedroGHM Feb 10, 2024
f1b568b
fix: double splitting on the factories.array method
JuanPedroGHM Feb 26, 2024
5525bb6
fix: added extra tests to cover found cases
JuanPedroGHM Feb 26, 2024
447083e
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Feb 28, 2024
b1ef6f9
Corrected typos, improved error messages, clearer treatment of `copy`…
JuanPedroGHM Mar 4, 2024
3a97d5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
afb7491
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Mar 4, 2024
6f878fa
wip: check if memory gets reused
JuanPedroGHM Mar 4, 2024
c505200
fix: replaced clone to contiguous
JuanPedroGHM Mar 14, 2024
fdb50f2
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Mar 25, 2024
1e14d7d
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Apr 2, 2024
820bf43
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Apr 8, 2024
91b7125
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Apr 15, 2024
95272ad
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
ClaudiaComito Apr 16, 2024
dffa4fa
fix: tutorial, missing target_map variable
JuanPedroGHM Apr 16, 2024
a84ba76
docs: copy example in tutorials
JuanPedroGHM Apr 16, 2024
1c08661
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
JuanPedroGHM Apr 17, 2024
7fcf288
Merge branch 'main' into bugs/1321-_Bug_factories_array_illegal_param…
ClaudiaComito Apr 17, 2024
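Two of the commits above ("wip: check if memory gets reused" and "fix: replaced clone to contiguous") concern whether the local chunk shares storage with the input tensor. The following plain-torch sketch (not part of the PR) shows why `.contiguous()` can avoid the allocation that `.clone()` always performs:

```python
import torch

t = torch.arange(12).reshape(4, 3)

# A row slice of a row-major tensor is already contiguous, so
# .contiguous() is a no-op and the existing storage is reused.
s = t[1:3]
assert s.contiguous().data_ptr() == s.data_ptr()

# .clone(), by contrast, unconditionally allocates a fresh buffer.
assert s.clone().data_ptr() != s.data_ptr()
```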
72 changes: 46 additions & 26 deletions heat/core/factories.py
@@ -9,7 +9,7 @@
 from .communication import MPI, sanitize_comm, Communication
 from .devices import Device
 from .dndarray import DNDarray
-from .memory import sanitize_memory_layout
+from .memory import sanitize_memory_layout, copy as memory_copy
 from .sanitation import sanitize_in, sanitize_sequence
 from .stride_tricks import sanitize_axis, sanitize_shape
 from .types import datatype
@@ -287,19 +287,6 @@ def array(
     11
     [torch.LongStorage of size 6]
     """
-    # array already exists; no copy
-    if isinstance(obj, DNDarray):
-        if not copy:
-            if (
-                (dtype is None or dtype == obj.dtype)
-                and (split is None or split == obj.split)
-                and (is_split is None or is_split == obj.split)
-                and (device is None or device == obj.device)
-            ):
-                return obj
-        # extract the internal tensor
-        obj = obj.larray
-
     # sanitize the data type
     if dtype is not None:
         dtype = types.canonical_heat_type(dtype)
@@ -308,6 +295,38 @@
     if device is not None:
         device = devices.sanitize_device(device)

+    if split is not None and is_split is not None:
+        raise ValueError("split and is_split are mutually exclusive parameters")
+
+    # array already exists; no copy
+    if isinstance(obj, DNDarray):
+        if (
+            (dtype is None or dtype == obj.dtype)
+            and (split is None or split == obj.split)
+            and (is_split is None or is_split == obj.split)
+            and (device is None or device == obj.device)
+        ):
+            if copy is True:
+                return memory_copy(obj)
+            else:
+                return obj
+        elif split is not None and obj.split is not None and split != obj.split:
+            raise ValueError(
+                f"'split' argument does not match existing 'split' dimension ({split} != {obj.split}).\nIf you are trying to create a new DNDarray with a new split from an existing DNDarray, use the function `ht.resplit()` instead."
+            )
+        elif is_split is not None and obj.split is not None and is_split != obj.split:
+            raise ValueError(
+                f"'is_split' and the split axis of the object do not match ({is_split} != {obj.split}).\nIf you are trying to resplit an existing DNDarray in-place, use the method `DNDarray.resplit_()` instead."
+            )
+        elif device is not None and device != obj.device and copy is False:
+            raise ValueError(
+                "argument `copy` is set to False, but a copy of the input object is necessary as the array is being moved across devices.\nUse the method `DNDarray.cpu()` or `DNDarray.gpu()` to move the array to the desired device."
+            )
+
+        # extract the internal tensor
+        obj = obj.larray
+
     # initialize the array
     if bool(copy):
         if isinstance(obj, torch.Tensor):
@@ -333,20 +352,23 @@
             (dtype is None or dtype == types.canonical_heat_type(obj.dtype))
             and (
                 device is None
-                or device.torch_device
-                == str(getattr(obj, "device", devices.get_device().torch_device))
+                or device.torch_device.split(":")[0]
+                == str(getattr(obj, "device", devices.get_device().torch_device)).split(":")[0]
             )
         ):
             raise ValueError(
                 "argument `copy` is set to False, but a copy of the input object is necessary.\nSet copy=None to reuse the memory buffer whenever possible and allow for copies otherwise."
             )
         try:
-            obj = torch.as_tensor(
-                obj,
-                device=(
-                    device.torch_device if device is not None else devices.get_device().torch_device
-                ),
-            )
+            if not isinstance(obj, torch.Tensor):
+                obj = torch.as_tensor(
+                    obj,
+                    device=(
+                        device.torch_device
+                        if device is not None
+                        else devices.get_device().torch_device
+                    ),
+                )
         except RuntimeError:
             raise TypeError(f"invalid data of type {type(obj)}")

@@ -359,7 +381,7 @@
         obj = obj.type(torch_dtype)

     # infer device from obj if not explicitly given
-    if device is None:
+    if device is None and hasattr(obj, "device"):
         device = devices.sanitize_device(obj.device.type)

     if str(obj.device) != device.torch_device:
@@ -383,8 +405,6 @@
     # sanitize the split axes, ensure mutual exclusiveness
     split = sanitize_axis(obj.shape, split)
     is_split = sanitize_axis(obj.shape, is_split)
-    if split is not None and is_split is not None:
-        raise ValueError("split and is_split are mutually exclusive parameters")

     # sanitize comm object
     comm = sanitize_comm(comm)
@@ -402,7 +422,7 @@
     elif split is not None:
         # only keep local slice
         _, _, slices = comm.chunk(gshape, split)
-        _ = obj[slices].clone()
+        _ = obj[slices].contiguous()
         del obj

     obj = sanitize_memory_layout(_, order=order)
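To make the new control flow concrete, here is a short sketch of the error cases introduced above. It assumes a single process and mirrors the new tests below; the comments paraphrase the error messages from the diff:

```python
import heat as ht

a = ht.array([[1.0, 2.0], [3.0, 4.0]], split=0)

# Requesting a different split of an existing DNDarray now raises a
# ValueError that points the user to ht.resplit() instead.
try:
    ht.array(a, split=1, copy=False)
except ValueError as e:
    print(e)

# A mismatched is_split likewise raises, pointing to DNDarray.resplit_().
try:
    ht.array(a, is_split=1, copy=False)
except ValueError as e:
    print(e)

# The relaxed device check compares only the device type, not the index,
# so "cuda:0" and "cuda" count as the same device for the copy=False path:
assert "cuda:0".split(":")[0] == "cuda".split(":")[0]
```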
63 changes: 63 additions & 0 deletions heat/core/tests/test_factories.py
@@ -200,6 +200,69 @@ def test_array(self):
             ).all()
         )

+        # distributed array, chunk local data (split), copy False, torch devices
+        array_2d = torch.tensor(
+            [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
+            dtype=torch.double,
+            device=self.device.torch_device,
+        )
+        dndarray_2d = ht.array(array_2d, split=0, copy=False, dtype=ht.double)
+        self.assertIsInstance(dndarray_2d, ht.DNDarray)
+        self.assertEqual(dndarray_2d.dtype, ht.float64)
+        self.assertEqual(dndarray_2d.gshape, (3, 3))
+        self.assertEqual(len(dndarray_2d.lshape), 2)
+        self.assertLessEqual(dndarray_2d.lshape[0], 3)
+        self.assertEqual(dndarray_2d.lshape[1], 3)
+        self.assertEqual(dndarray_2d.split, 0)
+        self.assertTrue(
+            (
+                dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device)
+            ).all()
+        )
+        # Check that the array is not a copy (this only really holds when the array is not split)
+        if ht.communication.MPI_WORLD.size == 1:
+            self.assertIs(dndarray_2d.larray, array_2d)
+
+        # The array should not change, as all properties match
+        dndarray_2d_new = ht.array(dndarray_2d, split=0, copy=False, dtype=ht.double)
+        self.assertIsInstance(dndarray_2d_new, ht.DNDarray)
+        self.assertEqual(dndarray_2d_new.dtype, ht.float64)
+        self.assertEqual(dndarray_2d_new.gshape, (3, 3))
+        self.assertEqual(len(dndarray_2d_new.lshape), 2)
+        self.assertLessEqual(dndarray_2d_new.lshape[0], 3)
+        self.assertEqual(dndarray_2d_new.lshape[1], 3)
+        self.assertEqual(dndarray_2d_new.split, 0)
+        self.assertTrue(
+            (
+                dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device)
+            ).all()
+        )
+        # Reuse the same array
+        self.assertIs(dndarray_2d_new.larray, dndarray_2d.larray)
+
+        # Should raise an exception because it would cause a resplit
+        with self.assertRaises(ValueError):
+            dndarray_2d_new = ht.array(dndarray_2d, split=1, copy=False, dtype=ht.double)
+
+        # The array should not change, as all properties match
+        dndarray_2d_new = ht.array(dndarray_2d, is_split=0, copy=False, dtype=ht.double)
+        self.assertIsInstance(dndarray_2d_new, ht.DNDarray)
+        self.assertEqual(dndarray_2d_new.dtype, ht.float64)
+        self.assertEqual(dndarray_2d_new.gshape, (3, 3))
+        self.assertEqual(len(dndarray_2d_new.lshape), 2)
+        self.assertLessEqual(dndarray_2d_new.lshape[0], 3)
+        self.assertEqual(dndarray_2d_new.lshape[1], 3)
+        self.assertEqual(dndarray_2d_new.split, 0)
+        self.assertTrue(
+            (
+                dndarray_2d.larray == torch.tensor([1.0, 2.0, 3.0], device=self.device.torch_device)
+            ).all()
+        )
+
+        # Should raise an exception because the array is split along another dimension
+        with self.assertRaises(ValueError):
+            dndarray_2d_new = ht.array(dndarray_2d, is_split=1, copy=False, dtype=ht.double)
+
         # distributed array, partial data (is_split)
         if ht.communication.MPI_WORLD.rank == 0:
             split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]