Add QR for non tall-skinny matrices and split=0 #1744

Open · wants to merge 9 commits into base: main
270 changes: 162 additions & 108 deletions heat/core/linalg/qr.py
@@ -7,6 +7,7 @@
from typing import Tuple

from ..dndarray import DNDarray
from ..manipulations import concatenate
from .. import factories
from .. import communication
from ..types import float32, float64
@@ -31,7 +32,6 @@ def qr(
----------
A : DNDarray of shape (M, N), of shape (...,M,N) in the batched case
Array which will be decomposed. So far, only arrays with datatype float32 or float64 are supported.
For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
mode : str, optional
default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified.
"r" returns only R, with dimensions (min(M,N), N).
@@ -46,13 +46,17 @@

- If ``A`` is distributed along the columns (A.split = 1), so will be ``Q`` and ``R``.

- If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`, but ``R`` won't be distributed, i.e. `R. split = None` and a full copy of ``R`` will be stored on each process.
- If ``A`` is distributed along the rows (A.split = 0), ``Q`` will also have `split=0`. ``R`` won't be distributed, i.e. `R.split = None`, if ``A`` is tall-skinny, i.e., if
the largest local chunk of data of ``A`` has at least as many rows as columns. Otherwise, ``R`` will be distributed along the rows as well, i.e., `R.split = 0`.

Note that the argument `calc_q` allowed in earlier Heat versions is no longer supported; `calc_q = False` is equivalent to `mode = "r"`.
Unlike ``numpy.linalg.qr()``, `ht.linalg.qr` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage.

Heat's QR function is built on top of PyTorch's QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on
the backend. For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
the backend. Both cases, split=0 and split=1, build on a column-block-wise version of stabilized Gram-Schmidt orthogonalization.
For split=1 (-1 in the batched case), this is applied directly to the local arrays of the input array.
For split=0, tall-skinny QR (TS-QR) is implemented for tall-skinny matrices (i.e., the largest local chunk of data has at least as many rows as columns)
and extended to non-tall-skinny matrices by applying a block-wise version of stabilized Gram-Schmidt orthogonalization.

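For illustration, a minimal usage sketch of the behaviour described above (not part of the diff; the matrix shapes, the assumed four MPI processes, and the splits noted in the comments are assumptions for this example and require an MPI launch such as ``mpirun -n 4 python example.py``):

import heat as ht

# 100 x 300 matrix distributed along the rows; with 4 processes each local chunk has
# roughly 25 rows and 300 columns, i.e. the new non-tall-skinny case for split=0
A = ht.random.randn(100, 300, dtype=ht.float32, split=0)
Q, R = ht.linalg.qr(A, mode="reduced")
print(Q.shape, Q.split, R.shape, R.split)   # expected: (100, 100) 0 (100, 300) 0

# tall-skinny input: every local chunk has more rows than columns, so R is
# replicated on each process (R.split = None)
B = ht.random.randn(10000, 50, dtype=ht.float32, split=0)
Q2, R2 = ht.linalg.qr(B)
print(R2.shape, R2.split)                   # expected: (50, 50) None
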
References
-----------
@@ -181,121 +185,171 @@ def qr(
return QR(Q, R)

if A.split == A.ndim - 2:
# implementation of TS-QR for split = 0
# check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data)
if A.lshape_map[:, -2].max().item() < A.shape[-1]:
raise ValueError(
"A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub."
# check that data distribution is reasonable for TS-QR
# we regard a matrix with split = 0 as suitable for TS-QR is largest local chunk of data has at least as many rows as columns
Reviewer comment (Contributor), suggested change:
# we regard a matrix with split = 0 as suitable for TS-QR is largest local chunk of data has at least as many rows as columns
# we regard a matrix with split = 0 as suitable for TS-QR if largest local chunk of data has at least as many rows as columns

biggest_number_of_local_rows = A.lshape_map[:, -2].max().item()
if biggest_number_of_local_rows < A.shape[-1]:
column_idx = torch.cumsum(A.lshape_map[:, -2], 0)
column_idx = column_idx[column_idx < A.shape[-1]]
column_idx = torch.cat(
(
torch.tensor([0], device=column_idx.device),
column_idx,
torch.tensor([A.shape[-1]], device=column_idx.device),
)
Comment on lines +192 to +199, @ClaudiaComito (Contributor), Jan 20, 2025:
You can use the DNDarray.counts_displs() method here,

_, column_idx = A.counts_displs()

(this returns a tuple though, and I think the final item [A.shape[-1]] needs to be added)

)
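For illustration, one possible shape of that suggestion (a sketch only, not tested against this branch; it relies on ``counts_displs()`` returning a ``(counts, displs)`` tuple along the split axis, as described in the comment above):

counts, displs = A.counts_displs()
# displs holds the starting row of each local chunk (first entry 0); appending the
# total number of rows reproduces [0] + cumsum(local row counts), and the closing
# boundary A.shape[-1] is added as noted above
boundaries = list(displs) + [sum(counts)]
column_idx = torch.tensor(
    [b for b in boundaries if b < A.shape[-1]] + [A.shape[-1]],
    device=A.larray.device,
)
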
A_copy = A.copy()
R = A.copy()
# Block-wise Gram-Schmidt orthogonalization, applied to groups of columns
offset = 1 if A.shape[-1] <= A.shape[-2] else 2
for k in range(len(column_idx) - offset):
Reviewer comment (Contributor):
If I understand correctly, each iteration needs A_copy[..., :, column_idx[k] :] only. Would it make sense to free up memory progressively here by only keeping the necessary slice of A_copy?
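
To make the memory point concrete, the following self-contained, single-process sketch (plain PyTorch, no MPI; an illustration rather than the PR's distributed code) performs block-wise Gram-Schmidt QR while shrinking the working copy after every column block, so that only the not-yet-factorized columns stay allocated:

import torch

def blockwise_qr_shrinking(A: torch.Tensor, block: int):
    # Block-wise (classical) Gram-Schmidt QR: factorize one column block at a time,
    # project it out of the remaining columns, then drop the finished block from the
    # working copy so its memory can be released.
    m, n = A.shape
    k = min(m, n)
    R = torch.zeros(k, n, dtype=A.dtype)
    Q_blocks = []
    work = A.clone()                                  # columns col .. n-1 of the input
    col = 0
    while col < k:
        w = min(block, k - col)
        Qb, Rb = torch.linalg.qr(work[:, :w], mode="reduced")
        Q_blocks.append(Qb)
        R[col : col + w, col : col + w] = Rb
        if col + w < n:
            coeffs = Qb.T @ work[:, w:]               # update of the remaining columns
            R[col : col + w, col + w :] = coeffs
            work = work[:, w:] - Qb @ coeffs          # shrink: the finished block is dropped
        col += w
    return torch.cat(Q_blocks, dim=1), R

A = torch.randn(100, 300, dtype=torch.float64)
Q, R = blockwise_qr_shrinking(A, block=25)
print(torch.allclose(Q @ R, A, atol=1e-10))                                       # reconstruction
print(torch.allclose(Q.T @ Q, torch.eye(100, dtype=torch.float64), atol=1e-8))    # orthogonality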

# since we only consider a group of columns, TS QR is applied to a tall-skinny matrix
Qnew, Rnew = qr(
A_copy[..., :, column_idx[k] : column_idx[k + 1]],
mode="reduced",
procs_to_merge=procs_to_merge,
)

current_procs = [i for i in range(A.comm.size)]
current_comm = A.comm
local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank)
Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode)
R_loc = R_loc.contiguous() # required for all the communication ops lateron
if mode == "reduced":
leave_comm = current_comm.Split(current_comm.rank, A.comm.rank)

level = 1
while len(current_procs) > 1:
if A.comm.rank in current_procs and local_comm.size > 1:
# create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes
shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0)
if local_comm.rank == 0:
gathered_R_loc = torch.zeros(
(*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]),
device=R_loc.device,
dtype=R_loc.dtype,
# usual update of the remaining columns
if R.comm.rank == k:
R.larray[
...,
: (column_idx[k + 1] - column_idx[k]),
column_idx[k] : column_idx[k + 1],
] = Rnew.larray
if R.comm.rank > k:
R.larray[..., :, column_idx[k] : column_idx[k + 1]] *= 0
if k < len(column_idx) - 2:
coeffs = (
torch.transpose(Qnew.larray, -2, -1)
@ A_copy.larray[..., :, column_idx[k + 1] :]
)
counts = list(shapes_R_loc)
displs = torch.cumsum(
torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0
).tolist()[:-1]
else:
gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype)
counts = None
displs = None
# gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2)
# perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc
if local_comm.rank == 0:
previous_shape = R_loc.shape
Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode)
R_loc = R_loc.contiguous()
else:
Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype)
R.comm.Allreduce(communication.MPI.IN_PLACE, coeffs)
if R.comm.rank == k:
R.larray[..., :, column_idx[k + 1] :] = coeffs
A_copy.larray[..., :, column_idx[k + 1] :] -= Qnew.larray @ coeffs
if mode == "reduced":
if local_comm.rank == 0:
Q_buf = Q_buf.contiguous()
scattered_Q_buf = torch.empty(
R_loc.shape if local_comm.rank != 0 else previous_shape,
device=R_loc.device,
dtype=R_loc.dtype,
)
# scatter the Q_buf to all processes of the process group
local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2)
del gathered_R_loc, Q_buf
Q = Qnew if k == 0 else concatenate((Q, Qnew), axis=-1)
if A.shape[-1] < A.shape[-2]:
R = R[..., : A.shape[-1], :].balance()
if mode == "reduced":
return QR(Q, R)
else:
return QR(None, R)

# for each process in the current processes, broadcast the scattered_Q_buf of this process
# to all leaves (i.e. all original processes that merge to the current process)
if mode == "reduced" and leave_comm.size > 1:
else:
# in this case the input is tall-skinny and we apply the TS-QR algorithm
# this follows the implementation of TS-QR for split = 0
current_procs = [i for i in range(A.comm.size)]
current_comm = A.comm
local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank)
Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode)
R_loc = R_loc.contiguous()  # required for all the communication ops later on
if mode == "reduced":
leave_comm = current_comm.Split(current_comm.rank, A.comm.rank)

level = 1
while len(current_procs) > 1:
if A.comm.rank in current_procs and local_comm.size > 1:
# create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes
shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0)
if local_comm.rank == 0:
gathered_R_loc = torch.zeros(
(*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]),
device=R_loc.device,
dtype=R_loc.dtype,
)
counts = list(shapes_R_loc)
displs = torch.cumsum(
torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0
).tolist()[:-1]
else:
gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype)
counts = None
displs = None
# gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2)
# perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc
if local_comm.rank == 0:
previous_shape = R_loc.shape
Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode)
R_loc = R_loc.contiguous()
else:
Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype)
if mode == "reduced":
if local_comm.rank == 0:
Q_buf = Q_buf.contiguous()
scattered_Q_buf = torch.empty(
R_loc.shape if local_comm.rank != 0 else previous_shape,
device=R_loc.device,
dtype=R_loc.dtype,
)
# scatter the Q_buf to all processes of the process group
local_comm.Scatterv(
(Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2
)
del gathered_R_loc, Q_buf

# for each process in the current processes, broadcast the scattered_Q_buf of this process
# to all leaves (i.e. all original processes that merge to the current process)
if mode == "reduced" and leave_comm.size > 1:
try:
scattered_Q_buf_shape = scattered_Q_buf.shape
except UnboundLocalError:
scattered_Q_buf_shape = None
scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0)
if scattered_Q_buf_shape is not None:
# this is needed to ensure that only those Q_loc get updated that are actually part of the current process group
if leave_comm.rank != 0:
scattered_Q_buf = torch.empty(
scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype
)
leave_comm.Bcast(scattered_Q_buf, root=0)
# update the local Q_loc by multiplying it with the scattered_Q_buf
try:
scattered_Q_buf_shape = scattered_Q_buf.shape
Q_loc = Q_loc @ scattered_Q_buf
del scattered_Q_buf
except UnboundLocalError:
scattered_Q_buf_shape = None
scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0)
if scattered_Q_buf_shape is not None:
# this is needed to ensure that only those Q_loc get updates that are actually part of the current process group
if leave_comm.rank != 0:
scattered_Q_buf = torch.empty(
scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype
pass

# update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering
current_procs = [
current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0
]
if len(current_procs) > 1:
new_group = A.comm.group.Incl(current_procs)
current_comm = A.comm.Create_group(new_group)
if A.comm.rank in current_procs:
local_comm = communication.MPICommunication(
current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank)
)
leave_comm.Bcast(scattered_Q_buf, root=0)
# update the local Q_loc by multiplying it with the scattered_Q_buf
try:
Q_loc = Q_loc @ scattered_Q_buf
del scattered_Q_buf
except UnboundLocalError:
pass

# update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering
current_procs = [
current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0
]
if len(current_procs) > 1:
new_group = A.comm.group.Incl(current_procs)
current_comm = A.comm.Create_group(new_group)
if A.comm.rank in current_procs:
local_comm = communication.MPICommunication(
current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank)
)
if mode == "reduced":
leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank)
level += 1
# broadcast the final R_loc to all processes
R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1])
if A.comm.rank != 0:
R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device)
A.comm.Bcast(R_loc, root=0)
R = DNDarray(
R_loc,
gshape=R_gshape,
dtype=A.dtype,
split=None,
device=A.device,
comm=A.comm,
balanced=True,
)
if mode == "r":
Q = None
else:
Q = DNDarray(
Q_loc,
gshape=A.shape,
if mode == "reduced":
leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank)
level += 1
# broadcast the final R_loc to all processes
R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1])
if A.comm.rank != 0:
R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device)
A.comm.Bcast(R_loc, root=0)
R = DNDarray(
R_loc,
gshape=R_gshape,
dtype=A.dtype,
split=A.split,
split=None,
device=A.device,
comm=A.comm,
balanced=True,
)
return QR(Q, R)
if mode == "r":
Q = None
else:
Q = DNDarray(
Q_loc,
gshape=A.shape,
dtype=A.dtype,
split=A.split,
device=A.device,
comm=A.comm,
balanced=True,
)
return QR(Q, R)
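
As a companion to the tall-skinny branch above, here is a small single-process simulation of the TS-QR merge idea (plain PyTorch; an illustration under the simplification that all blocks are merged in one step, whereas the code above merges the R factors of ``procs_to_merge`` processes at a time over several levels via Gatherv/Scatterv):

import torch

def tsqr_simulated(A: torch.Tensor, n_blocks: int):
    # TS-QR in one merge step: local QR per row block, QR of the stacked R factors,
    # then fold the corresponding slice of the merge Q back into each local Q.
    blocks = torch.chunk(A, n_blocks, dim=0)                        # "local" row chunks
    local_qr = [torch.linalg.qr(B, mode="reduced") for B in blocks]
    R_stack = torch.cat([R_loc for _, R_loc in local_qr], dim=0)    # gather step
    Q_merge, R = torch.linalg.qr(R_stack, mode="reduced")           # merge QR
    sizes = [R_loc.shape[0] for _, R_loc in local_qr]
    offsets = [0]
    for s in sizes:                                                  # row offsets into Q_merge
        offsets.append(offsets[-1] + s)
    Q = torch.cat(
        [Q_loc @ Q_merge[offsets[i] : offsets[i + 1]] for i, (Q_loc, _) in enumerate(local_qr)],
        dim=0,
    )                                                                # scatter/broadcast step
    return Q, R

A = torch.randn(200, 20, dtype=torch.float64)    # tall-skinny: far more rows than columns
Q, R = tsqr_simulated(A, n_blocks=4)
print(torch.allclose(Q @ R, A, atol=1e-12))
print(torch.allclose(Q.T @ Q, torch.eye(20, dtype=torch.float64), atol=1e-12))
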
15 changes: 9 additions & 6 deletions heat/core/linalg/tests/test_qr.py
@@ -76,7 +76,11 @@ def test_qr_split0(self):
for procs_to_merge in [0, 2, 3]:
for mode in ["reduced", "r"]:
# split = 0 can be handled only for tall-skinny matrices s.t. the local chunks are at least square too
for shape in [(40 * ht.MPI_WORLD.size + 1, 40), (40 * ht.MPI_WORLD.size, 20)]:
for shape in [
(20 * ht.MPI_WORLD.size + 1, 40 * ht.MPI_WORLD.size),
(20 * ht.MPI_WORLD.size, 20 * ht.MPI_WORLD.size),
(40 * ht.MPI_WORLD.size - 1, 20 * ht.MPI_WORLD.size),
]:
for dtype in [ht.float32, ht.float64]:
dtypetol = 1e-3 if dtype == ht.float32 else 1e-6
mat = ht.random.randn(*shape, dtype=dtype, split=split)
@@ -146,8 +150,11 @@ def test_batched_qr_split1(self):
self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6))

def test_batched_qr_split0(self):
ht.random.seed(424242)
# one batch dimension, float32 data type, "split = 0" (second last dimension)
x = ht.random.randn(8, ht.MPI_WORLD.size * 10 + 3, 9, dtype=ht.float32, split=1)
x = ht.random.randn(
8, ht.MPI_WORLD.size * 10 + 3, ht.MPI_WORLD.size * 10 - 1, dtype=ht.float32, split=1
)
q, r = ht.linalg.qr(x)
batched_id = ht.stack([ht.eye(q.shape[2], dtype=ht.float32) for _ in range(q.shape[0])])

@@ -178,7 +185,3 @@ def test_wronginputs(self):
# test wrong dtype
with self.assertRaises(TypeError):
ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32))
# test wrong shape for split=0
if ht.MPI_WORLD.size > 1:
with self.assertRaises(ValueError):
ht.linalg.qr(ht.zeros((10, 10), split=0))