diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 9c6604cdd4..4e3f0cea28 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -16,7 +16,7 @@ def qr( a: DNDarray, - tiles_per_proc: Union[int, torch.Tensor] = 1, + tiles_per_proc: Union[int, torch.Tensor] = 2, calc_q: bool = True, overwrite_a: bool = False, ) -> Tuple[DNDarray, DNDarray]: @@ -30,7 +30,8 @@ def qr( a : DNDarray Array which will be decomposed tiles_per_proc : int or torch.Tensor, optional - Number of tiles per process to operate on, + Number of tiles per process to operate on. + We strongly recommend using tiles_per_proc > 1, as tiles_per_proc = 1 might result in an error in certain situations (in particular for split=0). calc_q : bool, optional Whether or not to calculate Q. If ``True``, function returns ``(Q, R)``. @@ -89,6 +90,11 @@ def qr( if len(a.shape) != 2: raise ValueError("Array 'a' must be 2 dimensional") + if a.split == 0 and tiles_per_proc == 1: + raise Warning( + "Using tiles_per_proc=1 with split=0 can result in an error. We highly recommend to use tiles_per_proc > 1." 
+ ) + QR = collections.namedtuple("QR", "Q, R") if a.split is None: @@ -898,7 +904,7 @@ def __split1_qr_loop( except AttributeError: q1, r1 = r_tiles[dcol, dcol].qr(some=False) - r_tiles.arr.comm.Bcast(q1.clone(), root=diag_process) + r_tiles.arr.comm.Bcast(q1.clone(memory_format=torch.contiguous_format), root=diag_process) r_tiles[dcol, dcol] = r1 # apply q1 to the trailing matrix (other processes) diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index b19963ae1b..4cf7b2e3a1 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -22,7 +22,7 @@ def test_qr_sp0_ext(self): sp = 0 for m in range(50, st_whole.shape[0] + 1, 1): for n in range(50, st_whole.shape[1] + 1, 1): - for t in range(1, 3): + for t in range(2, 3): st = st_whole[:m, :n].clone() a_comp = ht.array(st, split=0) a = ht.array(st, split=sp) @@ -37,7 +37,7 @@ def test_qr_sp1_ext(self): sp = 1 for m in range(50, st_whole.shape[0] + 1, 1): for n in range(50, st_whole.shape[1] + 1, 1): - for t in range(1, 3): + for t in range(2, 3): st = st_whole[:m, :n].clone() a_comp = ht.array(st, split=0) a = ht.array(st, split=sp) @@ -50,7 +50,7 @@ def test_qr(self): m, n = 20, 40 st = torch.randn(m, n, device=self.device.torch_device, dtype=torch.float) a_comp = ht.array(st, split=0) - for t in range(1, 3): + for t in range(2, 3): for sp in range(2): a = ht.array(st, split=sp, dtype=torch.float) qr = a.qr(tiles_per_proc=t) @@ -60,7 +60,7 @@ def test_qr(self): m, n = 40, 40 st1 = torch.randn(m, n, device=self.device.torch_device) a_comp1 = ht.array(st1, split=0) - for t in range(1, 3): + for t in range(2, 3): for sp in range(2): a1 = ht.array(st1, split=sp) qr1 = a1.qr(tiles_per_proc=t) @@ -70,7 +70,7 @@ def test_qr(self): m, n = 40, 20 st2 = torch.randn(m, n, dtype=torch.double, device=self.device.torch_device) a_comp2 = ht.array(st2, split=0, dtype=ht.double) - for t in range(1, 3): + for t in range(2, 3): for sp in range(2): a2 = ht.array(st2, 
split=sp) qr2 = a2.qr(tiles_per_proc=t) diff --git a/heat/core/version.py b/heat/core/version.py index d680344436..9b655710a6 100644 --- a/heat/core/version.py +++ b/heat/core/version.py @@ -5,7 +5,7 @@ """Indicates Heat's main version.""" minor: int = 4 """Indicates feature extension.""" -micro: int = 0 +micro: int = 1 """Indicates revisions for bugfixes.""" extension: str = "dev" """Indicates special builds, e.g. for specific hardware."""