Skip to content

Commit

Permalink
Merge branch 'main' into bugs/1232-_Bug_Ensure_NumPy-compatibility_of…
Browse files Browse the repository at this point in the history
…_test_statistics_py
  • Loading branch information
mrfh92 authored Jan 22, 2024
2 parents b770ab2 + f8fd26a commit 6a5115c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
12 changes: 9 additions & 3 deletions heat/core/linalg/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def qr(
a: DNDarray,
tiles_per_proc: Union[int, torch.Tensor] = 1,
tiles_per_proc: Union[int, torch.Tensor] = 2,
calc_q: bool = True,
overwrite_a: bool = False,
) -> Tuple[DNDarray, DNDarray]:
Expand All @@ -30,7 +30,8 @@ def qr(
a : DNDarray
Array which will be decomposed
tiles_per_proc : int or torch.Tensor, optional
Number of tiles per process to operate on,
Number of tiles per process to operate on
We highly recommend to use tiles_per_proc > 1, as the choice 1 might result in an error in certain situations (in particular for split=0).
calc_q : bool, optional
Whether or not to calculate Q.
If ``True``, function returns ``(Q, R)``.
Expand Down Expand Up @@ -89,6 +90,11 @@ def qr(
if len(a.shape) != 2:
raise ValueError("Array 'a' must be 2 dimensional")

if a.split == 0 and tiles_per_proc == 1:
raise Warning(
"Using tiles_per_proc=1 with split=0 can result in an error. We highly recommend to use tiles_per_proc > 1."
)

QR = collections.namedtuple("QR", "Q, R")

if a.split is None:
Expand Down Expand Up @@ -898,7 +904,7 @@ def __split1_qr_loop(
except AttributeError:
q1, r1 = r_tiles[dcol, dcol].qr(some=False)

r_tiles.arr.comm.Bcast(q1.clone(), root=diag_process)
r_tiles.arr.comm.Bcast(q1.clone(memory_format=torch.contiguous_format), root=diag_process)
r_tiles[dcol, dcol] = r1
# apply q1 to the trailing matrix (other processes)

Expand Down
10 changes: 5 additions & 5 deletions heat/core/linalg/tests/test_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_qr_sp0_ext(self):
sp = 0
for m in range(50, st_whole.shape[0] + 1, 1):
for n in range(50, st_whole.shape[1] + 1, 1):
for t in range(1, 3):
for t in range(2, 3):
st = st_whole[:m, :n].clone()
a_comp = ht.array(st, split=0)
a = ht.array(st, split=sp)
Expand All @@ -37,7 +37,7 @@ def test_qr_sp1_ext(self):
sp = 1
for m in range(50, st_whole.shape[0] + 1, 1):
for n in range(50, st_whole.shape[1] + 1, 1):
for t in range(1, 3):
for t in range(2, 3):
st = st_whole[:m, :n].clone()
a_comp = ht.array(st, split=0)
a = ht.array(st, split=sp)
Expand All @@ -50,7 +50,7 @@ def test_qr(self):
m, n = 20, 40
st = torch.randn(m, n, device=self.device.torch_device, dtype=torch.float)
a_comp = ht.array(st, split=0)
for t in range(1, 3):
for t in range(2, 3):
for sp in range(2):
a = ht.array(st, split=sp, dtype=torch.float)
qr = a.qr(tiles_per_proc=t)
Expand All @@ -60,7 +60,7 @@ def test_qr(self):
m, n = 40, 40
st1 = torch.randn(m, n, device=self.device.torch_device)
a_comp1 = ht.array(st1, split=0)
for t in range(1, 3):
for t in range(2, 3):
for sp in range(2):
a1 = ht.array(st1, split=sp)
qr1 = a1.qr(tiles_per_proc=t)
Expand All @@ -70,7 +70,7 @@ def test_qr(self):
m, n = 40, 20
st2 = torch.randn(m, n, dtype=torch.double, device=self.device.torch_device)
a_comp2 = ht.array(st2, split=0, dtype=ht.double)
for t in range(1, 3):
for t in range(2, 3):
for sp in range(2):
a2 = ht.array(st2, split=sp)
qr2 = a2.qr(tiles_per_proc=t)
Expand Down
2 changes: 1 addition & 1 deletion heat/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""Indicates Heat's main version."""
minor: int = 4
"""Indicates feature extension."""
micro: int = 0
micro: int = 1
"""Indicates revisions for bugfixes."""
extension: str = "dev"
"""Indicates special builds, e.g. for specific hardware."""
Expand Down

0 comments on commit 6a5115c

Please sign in to comment.