Add QR for non tall-skinny matrices and split=0
#1744
base: main
Changes from all commits
704d8f0
7580552
6430afd
7b26be1
3a77070
b6bd730
d7f4be5
76ed6fb
c5c228e
@@ -7,6 +7,7 @@
from typing import Tuple

from ..dndarray import DNDarray
+ from ..manipulations import concatenate
from .. import factories
from .. import communication
from ..types import float32, float64

@@ -31,7 +32,6 @@ def qr(
----------
A : DNDarray of shape (M, N), of shape (...,M,N) in the batched case
    Array which will be decomposed. So far only arrays with datatype float32 or float64 are supported
-   For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns.
mode : str, optional
    default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified.
    "r" returns only R, with dimensions (min(M,N), N).

@@ -46,13 +46,17 @@ def qr(

- If ``A`` is distributed along the columns (A.split = 1), so will be ``Q`` and ``R``.

-   - If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`, but ``R`` won't be distributed, i.e. `R.split = None`, and a full copy of ``R`` will be stored on each process.
+   - If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`. ``R`` won't be distributed, i.e. `R.split = None`, if ``A`` is tall-skinny, i.e., if
+     the largest local chunk of data of ``A`` has at least as many rows as columns. Otherwise, ``R`` will be distributed along the rows as well, i.e., `R.split = 0`.

Note that the argument `calc_q` allowed in earlier Heat versions is no longer supported; `calc_q = False` is equivalent to `mode = "r"`.
Unlike ``numpy.linalg.qr()``, `ht.linalg.qr` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage.

Heat's QR function is built on top of PyTorch's QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on
-   the backend. For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used.
+   the backend. Both the split=0 and split=1 cases build on a column-block-wise version of stabilized Gram-Schmidt orthogonalization.
+   For split=1 (-1, in the batched case), this is applied directly to the local arrays of the input array.
+   For split=0, a tall-skinny QR (TS-QR) is implemented for the case of tall-skinny matrices (i.e., the largest local chunk of data has at least as many rows as columns)
+   and is extended to non-tall-skinny matrices by applying a block-wise version of stabilized Gram-Schmidt orthogonalization.

References
-----------

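A hypothetical usage sketch (not part of this diff) of the behaviour described in the updated docstring; the shapes, the process count, and the tolerances are illustrative, assuming the script is launched with something like `mpirun -n 4`:

```python
# hypothetical example, not part of this PR; run with e.g. `mpirun -n 4 python qr_example.py`
import heat as ht

# tall-skinny case: every local chunk has at least as many rows as columns
A_tall = ht.random.randn(10000, 100, split=0)
Q, R = ht.linalg.qr(A_tall)
print(Q.split, R.split)   # expected per the updated docstring: 0 and None (R fully replicated)

# non-tall-skinny case: the local chunks have more columns than rows
A_wide = ht.random.randn(100, 10000, split=0)
Q, R = ht.linalg.qr(A_wide)
print(Q.split, R.split)   # expected per the updated docstring: 0 and 0 (R distributed along the rows)

# in both cases Q @ R should reconstruct A up to floating-point error
print(ht.allclose(Q @ R, A_wide, atol=1e-4, rtol=1e-4))
```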
@@ -181,121 +185,171 @@ def qr(
return QR(Q, R) | ||
|
||
if A.split == A.ndim - 2: | ||
# implementation of TS-QR for split = 0 | ||
# check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data) | ||
if A.lshape_map[:, -2].max().item() < A.shape[-1]:
raise ValueError(
"A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recommend to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub."
# check that data distribution is reasonable for TS-QR
# we regard a matrix with split = 0 as suitable for TS-QR if the largest local chunk of data has at least as many rows as columns
biggest_number_of_local_rows = A.lshape_map[:, -2].max().item()
if biggest_number_of_local_rows < A.shape[-1]:
column_idx = torch.cumsum(A.lshape_map[:, -2], 0) | ||
column_idx = column_idx[column_idx < A.shape[-1]] | ||
column_idx = torch.cat( | ||
( | ||
torch.tensor([0], device=column_idx.device), | ||
column_idx, | ||
torch.tensor([A.shape[-1]], device=column_idx.device), | ||
) | ||
Review comment on lines +192 to +199: You can use the (this returns a tuple though, and I think the final item [A.shape[-1]] needs to be added)
||
) | ||
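As an aside to the comment above, a minimal single-process sketch of the boundary computation in lines +192 to +199; `local_row_counts` and `n_cols` are illustrative stand-ins for `A.lshape_map[:, -2]` and `A.shape[-1]`:

```python
import torch

local_row_counts = torch.tensor([3, 3, 2])  # hypothetical rows per process (split=0)
n_cols = 7                                  # hypothetical number of columns of A

# cumulative row offsets, truncated at the number of columns and padded with 0 and n_cols,
# mirroring the torch.cumsum / torch.cat construction in the diff
column_idx = torch.cumsum(local_row_counts, 0)
column_idx = column_idx[column_idx < n_cols]
column_idx = torch.cat((torch.tensor([0]), column_idx, torch.tensor([n_cols])))

print(column_idx)  # tensor([0, 3, 6, 7]) -> column blocks [0:3], [3:6], [6:7]
```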
A_copy = A.copy() | ||
R = A.copy() | ||
# Block-wise Gram-Schmidt orthogonalization, applied to groups of columns | ||
offset = 1 if A.shape[-1] <= A.shape[-2] else 2 | ||
for k in range(len(column_idx) - offset): | ||
Review comment: If I understand correctly, each iteration needs
||
# since we only consider a group of columns, TS QR is applied to a tall-skinny matrix | ||
Qnew, Rnew = qr( | ||
A_copy[..., :, column_idx[k] : column_idx[k + 1]], | ||
mode="reduced", | ||
procs_to_merge=procs_to_merge, | ||
) | ||
|
||
current_procs = [i for i in range(A.comm.size)] | ||
current_comm = A.comm | ||
local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) | ||
Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode) | ||
R_loc = R_loc.contiguous()  # required for all the communication ops later on | ||
if mode == "reduced": | ||
leave_comm = current_comm.Split(current_comm.rank, A.comm.rank) | ||
|
||
level = 1 | ||
while len(current_procs) > 1: | ||
if A.comm.rank in current_procs and local_comm.size > 1: | ||
# create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes | ||
shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0) | ||
if local_comm.rank == 0: | ||
gathered_R_loc = torch.zeros( | ||
(*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]), | ||
device=R_loc.device, | ||
dtype=R_loc.dtype, | ||
# usual update of the remaining columns | ||
if R.comm.rank == k: | ||
R.larray[ | ||
..., | ||
: (column_idx[k + 1] - column_idx[k]), | ||
column_idx[k] : column_idx[k + 1], | ||
] = Rnew.larray | ||
if R.comm.rank > k: | ||
R.larray[..., :, column_idx[k] : column_idx[k + 1]] *= 0 | ||
if k < len(column_idx) - 2: | ||
coeffs = ( | ||
torch.transpose(Qnew.larray, -2, -1) | ||
@ A_copy.larray[..., :, column_idx[k + 1] :] | ||
) | ||
counts = list(shapes_R_loc) | ||
displs = torch.cumsum( | ||
torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0 | ||
).tolist()[:-1] | ||
else: | ||
gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) | ||
counts = None | ||
displs = None | ||
# gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes | ||
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2) | ||
# perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc | ||
if local_comm.rank == 0: | ||
previous_shape = R_loc.shape | ||
Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode) | ||
R_loc = R_loc.contiguous() | ||
else: | ||
Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) | ||
R.comm.Allreduce(communication.MPI.IN_PLACE, coeffs) | ||
if R.comm.rank == k: | ||
R.larray[..., :, column_idx[k + 1] :] = coeffs | ||
A_copy.larray[..., :, column_idx[k + 1] :] -= Qnew.larray @ coeffs | ||
if mode == "reduced": | ||
if local_comm.rank == 0: | ||
Q_buf = Q_buf.contiguous() | ||
scattered_Q_buf = torch.empty( | ||
R_loc.shape if local_comm.rank != 0 else previous_shape, | ||
device=R_loc.device, | ||
dtype=R_loc.dtype, | ||
) | ||
# scatter the Q_buf to all processes of the process group | ||
local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2) | ||
del gathered_R_loc, Q_buf | ||
Q = Qnew if k == 0 else concatenate((Q, Qnew), axis=-1) | ||
if A.shape[-1] < A.shape[-2]: | ||
R = R[..., : A.shape[-1], :].balance() | ||
if mode == "reduced": | ||
return QR(Q, R) | ||
else: | ||
return QR(None, R) | ||
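To clarify the column-block Gram-Schmidt idea used in this branch, a hedged single-process sketch in plain PyTorch (no MPI); block sizes and variable names are illustrative, and in the distributed code above the per-block QR is the TS-QR while `coeffs` is accumulated with an `Allreduce`:

```python
import torch

torch.manual_seed(0)
m, n, block = 6, 10, 3                      # a wide matrix: more columns than rows
A = torch.randn(m, n, dtype=torch.float64)

A_work = A.clone()
R = torch.zeros(min(m, n), n, dtype=torch.float64)
Q_cols = []

# process the columns block by block: QR-factorize the current block,
# then remove its contribution from all remaining columns (Gram-Schmidt update)
col_starts = list(range(0, min(m, n), block)) + [min(m, n)]
for k in range(len(col_starts) - 1):
    lo, hi = col_starts[k], col_starts[k + 1]
    Qk, Rk = torch.linalg.qr(A_work[:, lo:hi], mode="reduced")
    R[lo:hi, lo:hi] = Rk
    coeffs = Qk.transpose(-2, -1) @ A_work[:, hi:]   # projections onto the new basis
    R[lo:hi, hi:] = coeffs
    A_work[:, hi:] -= Qk @ coeffs                    # orthogonalize the remaining columns
    Q_cols.append(Qk)

Q = torch.cat(Q_cols, dim=-1)                        # shape (m, min(m, n))
print(torch.allclose(Q @ R, A, atol=1e-10))          # True
```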
|
||
# for each process in the current processes, broadcast the scattered_Q_buf of this process | ||
# to all leaves (i.e. all original processes that merge to the current process) | ||
if mode == "reduced" and leave_comm.size > 1: | ||
else: | ||
# in this case the input is tall-skinny and we apply the TS-QR algorithm | ||
# it follows the implementation of TS-QR for split = 0 | ||
current_procs = [i for i in range(A.comm.size)] | ||
current_comm = A.comm | ||
local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) | ||
Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode) | ||
R_loc = R_loc.contiguous()  # required for all the communication ops later on | ||
if mode == "reduced": | ||
leave_comm = current_comm.Split(current_comm.rank, A.comm.rank) | ||
|
||
level = 1 | ||
while len(current_procs) > 1: | ||
if A.comm.rank in current_procs and local_comm.size > 1: | ||
# create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes | ||
shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0) | ||
if local_comm.rank == 0: | ||
gathered_R_loc = torch.zeros( | ||
(*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]), | ||
device=R_loc.device, | ||
dtype=R_loc.dtype, | ||
) | ||
counts = list(shapes_R_loc) | ||
displs = torch.cumsum( | ||
torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0 | ||
).tolist()[:-1] | ||
else: | ||
gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) | ||
counts = None | ||
displs = None | ||
# gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes | ||
local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2) | ||
# perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc | ||
if local_comm.rank == 0: | ||
previous_shape = R_loc.shape | ||
Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode) | ||
R_loc = R_loc.contiguous() | ||
else: | ||
Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) | ||
if mode == "reduced": | ||
if local_comm.rank == 0: | ||
Q_buf = Q_buf.contiguous() | ||
scattered_Q_buf = torch.empty( | ||
R_loc.shape if local_comm.rank != 0 else previous_shape, | ||
device=R_loc.device, | ||
dtype=R_loc.dtype, | ||
) | ||
# scatter the Q_buf to all processes of the process group | ||
local_comm.Scatterv( | ||
(Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2 | ||
) | ||
del gathered_R_loc, Q_buf | ||
|
||
# for each process in the current processes, broadcast the scattered_Q_buf of this process | ||
# to all leaves (i.e. all original processes that merge to the current process) | ||
if mode == "reduced" and leave_comm.size > 1: | ||
try: | ||
scattered_Q_buf_shape = scattered_Q_buf.shape | ||
except UnboundLocalError: | ||
scattered_Q_buf_shape = None | ||
scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0) | ||
if scattered_Q_buf_shape is not None: | ||
# this is needed to ensure that only those Q_loc get updates that are actually part of the current process group | ||
if leave_comm.rank != 0: | ||
scattered_Q_buf = torch.empty( | ||
scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype | ||
) | ||
leave_comm.Bcast(scattered_Q_buf, root=0) | ||
# update the local Q_loc by multiplying it with the scattered_Q_buf | ||
try: | ||
scattered_Q_buf_shape = scattered_Q_buf.shape | ||
Q_loc = Q_loc @ scattered_Q_buf | ||
del scattered_Q_buf | ||
except UnboundLocalError: | ||
scattered_Q_buf_shape = None | ||
scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0) | ||
if scattered_Q_buf_shape is not None: | ||
# this is needed to ensure that only those Q_loc get updates that are actually part of the current process group | ||
if leave_comm.rank != 0: | ||
scattered_Q_buf = torch.empty( | ||
scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype | ||
pass | ||
|
||
# update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering | ||
current_procs = [ | ||
current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0 | ||
] | ||
if len(current_procs) > 1: | ||
new_group = A.comm.group.Incl(current_procs) | ||
current_comm = A.comm.Create_group(new_group) | ||
if A.comm.rank in current_procs: | ||
local_comm = communication.MPICommunication( | ||
current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) | ||
) | ||
leave_comm.Bcast(scattered_Q_buf, root=0) | ||
# update the local Q_loc by multiplying it with the scattered_Q_buf | ||
try: | ||
Q_loc = Q_loc @ scattered_Q_buf | ||
del scattered_Q_buf | ||
except UnboundLocalError: | ||
pass | ||
|
||
# update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering | ||
current_procs = [ | ||
current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0 | ||
] | ||
if len(current_procs) > 1: | ||
new_group = A.comm.group.Incl(current_procs) | ||
current_comm = A.comm.Create_group(new_group) | ||
if A.comm.rank in current_procs: | ||
local_comm = communication.MPICommunication( | ||
current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) | ||
) | ||
if mode == "reduced": | ||
leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank) | ||
level += 1 | ||
# broadcast the final R_loc to all processes | ||
R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1]) | ||
if A.comm.rank != 0: | ||
R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device) | ||
A.comm.Bcast(R_loc, root=0) | ||
R = DNDarray( | ||
R_loc, | ||
gshape=R_gshape, | ||
dtype=A.dtype, | ||
split=None, | ||
device=A.device, | ||
comm=A.comm, | ||
balanced=True, | ||
) | ||
if mode == "r": | ||
Q = None | ||
else: | ||
Q = DNDarray( | ||
Q_loc, | ||
gshape=A.shape, | ||
if mode == "reduced": | ||
leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank) | ||
level += 1 | ||
# broadcast the final R_loc to all processes | ||
R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1]) | ||
if A.comm.rank != 0: | ||
R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device) | ||
A.comm.Bcast(R_loc, root=0) | ||
R = DNDarray( | ||
R_loc, | ||
gshape=R_gshape, | ||
dtype=A.dtype, | ||
split=A.split, | ||
split=None, | ||
device=A.device, | ||
comm=A.comm, | ||
balanced=True, | ||
) | ||
return QR(Q, R) | ||
if mode == "r": | ||
Q = None | ||
else: | ||
Q = DNDarray( | ||
Q_loc, | ||
gshape=A.shape, | ||
dtype=A.dtype, | ||
split=A.split, | ||
device=A.device, | ||
comm=A.comm, | ||
balanced=True, | ||
) | ||
return QR(Q, R) |
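Finally, to make the TS-QR merge step above easier to follow, a single-process sketch in plain PyTorch (no MPI); the two "process" blocks and all names are illustrative, and in the actual code the stacking and redistribution happen via `Gatherv`/`Scatterv` across the process groups:

```python
import torch

torch.manual_seed(0)

# two "processes", each holding a tall-skinny local block of A (split=0)
A_blocks = [torch.randn(50, 8, dtype=torch.float64) for _ in range(2)]

# step 1: local QR on every block (corresponds to torch.linalg.qr(A.larray) in the diff)
Q_blocks, R_blocks = zip(*(torch.linalg.qr(blk, mode="reduced") for blk in A_blocks))

# step 2: "merge": stack the local R factors and QR-decompose the stack
# (this is what Gatherv collects on the group's root and torch.linalg.qr processes there)
stacked_R = torch.cat(R_blocks, dim=-2)
Q_merge, R = torch.linalg.qr(stacked_R, mode="reduced")

# step 3: hand each block its slice of Q_merge and update the local Q factors
offsets = torch.cumsum(torch.tensor([0] + [r.shape[-2] for r in R_blocks]), 0)
Q_blocks = [
    q @ Q_merge[offsets[i] : offsets[i + 1], :] for i, q in enumerate(Q_blocks)
]

# sanity check: the updated Q blocks and the merged R reproduce A
A_full = torch.cat(A_blocks, dim=-2)
Q_full = torch.cat(Q_blocks, dim=-2)
print(torch.allclose(Q_full @ R, A_full, atol=1e-10))  # True
```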