From 914e5c3a201f8957a2b9ceb8c0f60ec2b32ef8eb Mon Sep 17 00:00:00 2001 From: FOsterfeld <146953335+FOsterfeld@users.noreply.github.com> Date: Mon, 9 Sep 2024 05:55:45 +0200 Subject: [PATCH] Batched matrix multiplication. (#1261) * first implementation of the minimal solution split dimension is a batch dimension * access b.gshape[-2] only if input is not batched * fixed batched condition * throw a NotImplementedError for wrong split dimension on batched matmul * fixed dimension condition * added test for batched matmul with split dimension being a batch dimension * fixed condition for different batch dimensions * added some tests for correctly thrown errors * fixed test for batched matmul on gpu * test for batched matmul on gpu * remove unnecessary test with device=gpu * batched matmul with split==None for both matrices * implemented batched matmul for case split 00 * implemented batched matmul for case split 01 * implemented batched matmul for case split 11 * cleaned up code to return the result * added tests for the batched matmul * added batched matmul tests for float values * improved exception throwing: error message when only one matrix has split None * warn against the inefficient split cases in the matmul docstring * Update basics.py updated docs of matmul: warning on unfavourable split combinations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update basics.py extended docs on batched matmul * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed style complaints * Apply suggestions from code review Co-authored-by: Michael Tarnawa * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed documentation * updated matmul tests for new batch behavior * restructured code to remove code duplication of batched and unbatched cases generalized split 1-0 case to batched matrices * generalized the split case None-None to batched matrices small code restructuring added batched tests for all la split combinations * simplified the cases where not both matrices are split in la dimensions * generalized the None splits for batched matrices added None split to tests for batched matrices * removed unnecessary import * updated docstring * initialize random generator * refactored code for None splits --------- Co-authored-by: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hoppe Co-authored-by: Michael Tarnawa --- heat/core/linalg/basics.py | 1223 +++++++++++++------------ heat/core/linalg/tests/test_basics.py | 106 +++ heat/core/linalg/tests/test_qr.py | 9 + 3 files changed, 759 insertions(+), 579 deletions(-) diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py index 03793185bd..53e5e94e82 100644 --- a/heat/core/linalg/basics.py +++ b/heat/core/linalg/basics.py @@ -423,26 +423,28 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: """ Matrix multiplication of two ``DNDarrays``: ``a@b=c`` or ``A@B=c``. Returns a tensor with the result of ``a@b``. The split dimension of the returned array is - typically the split dimension of a. However, if ``a.split=None`` then the the ``c.split`` will be - set as the split dimension of ``b``. If both are ``None`` then ``c.split`` is also ``None``. + typically the split dimension of a. 
If both are ``None`` and if ``allow_resplit=False`` then ``c.split`` is also ``None``. + + Batched inputs (with batch dimensions being leading dimensions) are allowed; see also the Notes below. Parameters - ---------- + ----------- a : DNDarray - 2 dimensional: :math:`L \\times P` + matrix :math:`L \\times P` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k [\\times L] \\times P` b : DNDarray - 2 dimensional: :math:`P \\times Q` + matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k \\times P [\\times Q]` allow_resplit : bool, optional Whether to distribute ``a`` in the case that both ``a.split is None`` and ``b.split is None``. Default is ``False``. If ``True``, if both are not split then ``a`` will be distributed in-place along axis 0. Notes - ----- - - If ``a`` is a split vector then the returned vector will be of shape (:math:`1xQ`) and will be split in the 1st dimension - - If ``b`` is a vector and either ``a`` or ``b`` is split, then the returned vector will be of shape (:math:`Lx1`) and will be split in the 0th dimension + ----------- + - For batched inputs, batch dimensions must coincide and if one matrix is split along a batch axis the other must be split along the same axis. + - If ``a`` or ``b`` is a (possibly batched) vector the result will also be a (possibly batched) vector. + - We recommend to avoid the particular split combinations ``1``-``0``, ``None``-``0``, and ``1``-``None`` (for ``a.split``-``b.split``) due to their comparably high memory consumption, if possible. Applying ``DNDarray.resplit_`` or ``heat.resplit`` on one of the two factors before calling ``matmul`` in these situations might improve performance of your code / might avoid memory bottlenecks. References - ---------- + ----------- [1] R. Gu, et al., "Improving Execution Concurrency of Large-scale Matrix Multiplication on Distributed Data-parallel Platforms," IEEE Transactions on Parallel and Distributed Systems, vol 28, no. 9. 2017. \n @@ -450,8 +452,8 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: Accelerators," 2018 IEEE International Parallel and Distributed Processing Symposium Workshops (IPDPSW), Vancouver, BC, 2018, pp. 877-882. - Example - ------- + Examples + ----------- >>> a = ht.ones((n, m), split=1) >>> a[0] = ht.arange(1, m + 1) >>> a[:, -1] = ht.arange(1, n + 1).larray @@ -473,7 +475,6 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: [1/1] tensor([[3., 1., 1., 1., 1., 1., 1.], [4., 1., 1., 1., 1., 1., 1.]]) >>> linalg.matmul(a, b).larray - [0/1] tensor([[18., 8., 9., 10.], [14., 6., 7., 8.], [18., 7., 8., 9.], @@ -488,15 +489,56 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: sanitation.sanitize_in(a) sanitation.sanitize_in(b) - if a.gshape[-1] != b.gshape[0]: - raise ValueError( - f"If the last dimension of a ({a.gshape[-1]}) is not the same size as the second-to-last dimension of b. 
({b.gshape[-2]})" - ) + batch_dim = max(a.ndim, b.ndim) - 2 # -1 for vector vector multiplication + batched = batch_dim > 0 + + if batched and a.gshape[:batch_dim] != b.gshape[:batch_dim]: + raise ValueError("Batch dimensions must have the same shape!") + + batch_shape = a.gshape[:batch_dim] + + # if they are vectors they need to be expanded to be the proper dimensions + vector_flag_a = vector_flag_b = False + # if a.ndim >= 2 or b.ndim >= 2: # other case gets early out + if a.ndim == b.ndim - 1: + vector_flag_a = True + elif b.ndim == a.ndim - 1: + vector_flag_b = True + vector_flag = vector_flag_a or vector_flag_b # run squeeze at the end + + if not vector_flag and a.ndim != b.ndim: + raise ValueError("Number of batch dimensions must be the same!") + + if batch_dim >= 0: # not vector vector mult + na = a.gshape[-1] + mb = b.gshape[-2] if not vector_flag_b else b.gshape[-1] + if na != mb: + raise ValueError( + f"The last dimension of a ({a.gshape[-1]}) is not the same size as the second-to-last dimension of b. ({b.gshape[-2]})" + ) + + if batched: + # check for valid batched split of a and b + # if one is split along a batch axis, both matrices must be split along that axis + if ( + a.split is not None + and a.split < batch_dim + or b.split is not None + and b.split < batch_dim + ) and a.split != b.split: # not the same batch axis for split + raise NotImplementedError( + "Both input matrices have to be split along the same batch axis!" + ) + + comm = a.comm + ndim = max(a.ndim, b.ndim) + dev = a.device + tdev = dev.torch_device # determine if a larger type is needed for c c_type = types.promote_types(a.dtype, b.dtype) gpu_int_flag = False - if str(a.device)[:3] == "gpu": + if str(dev)[:3] == "gpu": og_type = c_type if c_type in [types.uint8, types.int8, types.int16, types.int32]: c_type = types.float32 @@ -506,596 +548,618 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: gpu_int_flag = True if a.dtype != c_type: - a = c_type(a, device=a.device) + a = c_type(a, device=dev) if b.dtype != c_type: - b = c_type(b, device=b.device) - - # early out for single-process setup, torch matmul - if a.comm.size == 1: - ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device) - if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret + b = c_type(b, device=dev) - if a.split is None and b.split is None: # matmul from torch - if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit: - # if either of A or B is a vector - ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device, comm=a.comm) - if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret + c = None - a.resplit_(0) - slice_0 = a.comm.chunk(a.shape, a.split)[2][0] - hold = a.larray @ b.larray - - c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm) - c.larray[slice_0.start : slice_0.stop, :] += hold - c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) - if gpu_int_flag: - c = og_type(c, device=a.device) - return c + # single-process setup, torch matmul + if a.comm.size == 1: + c = factories.array(torch.matmul(a.larray, b.larray), dtype=c_type, device=dev) - # if they are vectors they need to be expanded to be the proper dimensions - vector_flag = False # flag to run squeeze at the end of the function - if len(a.gshape) < 2 and len(b.gshape) < 2: + # early out for vector vector multiplication + # is this even covered in the tests? 
+ # seems to be used in test_qr + elif a.ndim == 1 and b.ndim == 1: # make both split 0, do a local mm then a sum a.resplit_(0) b.resplit_(0) res = a.larray @ b.larray a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM) - ret = factories.array(res, split=None, device=a.device, comm=a.comm) - if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret - elif len(a.gshape) < 2: - a = manipulations.expand_dims(a, axis=0) - vector_flag = True - elif len(b.gshape) < 2: - b = manipulations.expand_dims(b, axis=1) - vector_flag = True - - split_0_flag = False - split_1_flag = False - split_01_flag = False - split_10_flag = False - - tdev = a.device.torch_device - - if ( - (a.split == 0 and b.split is None) or (a.split is None and b.split == 1) - ) and not vector_flag: - split = a.split if a.split is not None else b.split - split = split if not vector_flag else 0 - c = factories.zeros( - (a.gshape[-2], b.gshape[1]), split=split, dtype=c_type, device=a.device, comm=a.comm + c = factories.array(res, split=None, device=dev, comm=comm) + + elif a.split is None and b.split is None: # None-None + if allow_resplit and not vector_flag: # resplit a to 0 + a.resplit_(ndim - 2) + slice_0 = a.comm.chunk(a.shape, a.split)[2][0] + hold = a.larray @ b.larray + + c = factories.zeros( + (*batch_shape, a.gshape[-2], b.gshape[-1]), dtype=c_type, device=dev, comm=comm + ) + c.larray[..., slice_0.start : slice_0.stop, :] += hold + c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) + else: # torch matmul + c = factories.array( + torch.matmul(a.larray, b.larray), + dtype=c_type, + device=dev, + comm=comm, + ) + elif a.split is not None and a.split < batch_dim: # split in batch dimension + c = factories.array( + torch.matmul(a.larray, b.larray), + is_split=a.split, + dtype=c_type, + device=dev, + comm=comm, ) - c.larray += a.larray @ b.larray - ret = c if not vector_flag else c.squeeze() + if c is not None: # early out if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret + c = og_type(c, device=dev) - elif a.split == 1 and b.split is None: - c = torch.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=tdev) + return c - a_idx = a.comm.chunk(a.shape, a.split)[2] - c += a.larray @ b.larray[a_idx[1].start : a_idx[1].start + a.lshape[-1], :] - a.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) - c = c if not vector_flag else c.squeeze() - ret = factories.array( - c, split=a.split if b.gshape[1] > 1 else 0, device=a.device, comm=a.comm - ) - if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret + # vector expansions + if vector_flag_a: + a = manipulations.expand_dims(a, axis=batch_dim) + if vector_flag_b: + b = manipulations.expand_dims(b, axis=batch_dim + 1) + + c_shape = (*batch_shape, a.gshape[-2], b.gshape[-1]) + + # one split None => other one is la dimension + if a.split is None or b.split is None: + split = None + is_split = False + + if (a.split == ndim - 2 and b.split is None) or ( + a.split is None and b.split == ndim - 1 + ): # 0-None, None-1 + split = a.split if a.split is not None else b.split + is_split = True - elif a.split is None and b.split == 0: - c = torch.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=tdev) - b_idx = b.comm.chunk(b.shape, b.split)[2] - c += a.larray[:, b_idx[0].start : b_idx[0].start + b.lshape[0]] @ b.larray - b.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) - c = c if not vector_flag else c.squeeze() - ret = factories.array( - c, split=b.split if a.gshape[-2] > 1 else 0, device=a.device, comm=a.comm + c = a.larray @ b.larray + + elif 
a.split == ndim - 1 and b.split is None: # 1-None + split = a.split + + c = torch.zeros(c_shape, dtype=c_type.torch_type(), device=tdev) + + a_idx = comm.chunk(a.shape, a.split)[2] + c += ( + a.larray + @ b.larray[ + ..., a_idx[ndim - 1].start : a_idx[ndim - 1].start + a.lshape[ndim - 1], : + ] + ) + comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) + + elif a.split is None and b.split == ndim - 2: # None-0 + split = b.split + + c = torch.zeros(c_shape, dtype=c_type.torch_type(), device=tdev) + b_idx = b.comm.chunk(b.shape, b.split)[2] + c += ( + a.larray[..., b_idx[ndim - 2].start : b_idx[ndim - 2].start + b.lshape[ndim - 2]] + @ b.larray + ) + b.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) + + # early out + if vector_flag: # squeeze only in the la dimensions + # it could be sensible to resplit/rebalance in case a single node gets the whole vector + if split is not None and split > batch_dim: # split in dimension that gets squeezed + split = batch_dim + if c.numel() == 0: # empty tensor cannot be squeezed + c = torch.zeros((*batch_shape, 0), dtype=c_type.torch_type(), device=tdev) + else: + c = c.squeeze(batch_dim) + if c.ndim >= batch_dim + 2: + c = c.squeeze(batch_dim + 1) + + c = factories.array( + c, + split=split if not is_split else None, + is_split=split if is_split else None, + dtype=c_type, + device=dev, + comm=comm, ) - if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret - elif ( - a.split == 0 and b.split is None - ): # this case and the one below will only be reaching if one of them is a vector - c = torch.zeros((a.gshape[-2], b.lshape[1]), dtype=c_type.torch_type(), device=tdev) - a_idx = a.comm.chunk(a.shape, a.split)[2] - c[a_idx[0]] += a.larray @ b.larray - a.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) - c = c if not vector_flag else c.squeeze() - split = a.split if b.gshape[1] > 1 else 0 - split = split if not vector_flag else 0 - ret = factories.array(c, split=split, device=a.device, comm=a.comm) if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret + c = og_type(c, device=dev) - elif a.split is None and b.split == 1: - c = torch.zeros((a.gshape[-2], b.lshape[1]), dtype=c_type.torch_type(), device=tdev) - c += a.larray @ b.larray - c = c if not vector_flag else c.squeeze() - split = b.split if a.gshape[1] > 1 else 0 - split = split if not vector_flag else 0 - ret = factories.array(c, is_split=split, device=a.device, comm=a.comm) - if gpu_int_flag: - ret = og_type(ret, device=a.device) - return ret + return c - elif a.split == 0 and b.split == 0: - split_0_flag = True - elif a.split == 1 and b.split == 1: - split_1_flag = True - elif a.split == 0 and b.split == 1: - split_01_flag = True - elif a.split == 1 and b.split == 0: - split_10_flag = True else: - raise NotImplementedError("splits > 1 not implemented") - - # block sizes dont need to be the same. 
thy just need the same inner dimension (kB) - kB = 0 - rem_a, rem_b = [0] * 2 - if a.split == len(a.gshape) - 1 and b.split == len(a.gshape) - 2: - # if the split direction is the last dim in a and the first dim in b - # the max inner dim (kB) is the min value from the result of the integer division - # of the last dim of a/world size and the first dim of b/world size - kB = min([a.gshape[-1] // a.comm.size, b.gshape[0] // b.comm.size]) - elif a.split == len(a.gshape) - 2 and b.split == len(a.gshape) - 1: - kB = a.gshape[-1] - elif a.split == len(a.gshape) - 1: - kB = a.gshape[-1] // a.comm.size - elif b.split == len(a.gshape) - 2: - kB = b.gshape[0] // b.comm.size - kB = min(kB, a.gshape[-1]) - - if a.lshape[-1] % kB != 0 or (kB == 1 and a.lshape[-1] != 1): - rem_a = 1 - if b.lshape[0] % kB != 0 or (kB == 1 and b.lshape[-2] != 1): - rem_b = 1 - - # get the lshape map to determine what needs to be sent where as well as M and N - # lshape map dims -> {node, a=0, b=1, lshape} - lshape_map = torch.zeros((a.comm.size, 2, len(a.gshape)), dtype=int, device=tdev) - lshape_map[a.comm.rank, 0, :] = torch.tensor(a.lshape, device=tdev) - lshape_map[b.comm.rank, 1, :] = torch.tensor(b.lshape, device=tdev) - a.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM) - - # find mB (first blocking dim for a) and nB (2nd blocking dim for b) - mB = lshape_map[:, 0, -2].min().item() - nB = lshape_map[:, 1, -1].min().item() - - # check for remaining dims in the outside dimensions - rem_a_out, rem_b_out = 0, 0 - if a.lshape[-2] % mB != 0 or (kB == 1 and a.lshape[-2] != 1): - rem_a_out = 1 - if b.lshape[-1] % nB != 0 or (kB == 1 and b.lshape[-1] != 1): - rem_b_out = 1 - - # get the flags from all processes - # rem_map dims guide -> {process number, a/b (0/1), True/False (1/0) - # if there is a remainder in this dimension - rem_map = torch.zeros((a.comm.size, 2, 2)) - rem_map[a.comm.rank, 0, :] = torch.tensor((rem_a_out, rem_a), device=tdev) - rem_map[a.comm.rank, 1, :] = torch.tensor((rem_b, rem_b_out), device=tdev) - rem_map_comm = a.comm.Iallreduce(MPI.IN_PLACE, rem_map, MPI.SUM) - - # index_map dims guide -> {process number, a=0/b=1, relevent 1st index, 2nd index} - index_map = torch.zeros((a.comm.size, 2, 2, 2), dtype=int, device=tdev) - a_idx = a.comm.chunk(a.shape, a.split)[2] - index_map[a.comm.rank, 0, 0] = torch.tensor((a_idx[0].start, a_idx[0].stop), device=tdev) - index_map[a.comm.rank, 0, 1] = torch.tensor((a_idx[1].start, a_idx[1].stop), device=tdev) - b_idx = b.comm.chunk(b.shape, b.split)[2] - index_map[b.comm.rank, 1, 0] = torch.tensor((b_idx[0].start, b_idx[0].stop), device=tdev) - index_map[b.comm.rank, 1, 1] = torch.tensor((b_idx[1].start, b_idx[1].stop), device=tdev) - - index_map_comm = a.comm.Iallreduce(MPI.IN_PLACE, index_map, MPI.SUM) - - # for the communication scheme, the output array needs to be created - c_shape = (a.gshape[-2], b.gshape[1]) - c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=a.device, comm=a.comm) - - # get the index map for c - c_index_map = factories.zeros((c.comm.size, 2, 2), device=a.device, comm=a.comm) - c_idx = c.comm.chunk(c.shape, c.split)[2] - c_index_map[c.comm.rank, 0, :] = (c_idx[0].start, c_idx[0].stop) - c_index_map[c.comm.rank, 1, :] = (c_idx[1].start, c_idx[1].stop) - c_wait = c.comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM) - - if a.split == 0: - a_block_map = torch.zeros( - (a.comm.size, a.shape[-2] // mB // a.comm.size, a.shape[-1] // kB, 2), - dtype=torch.int, - device=tdev, - ) - elif a.split == 1: - a_block_map = torch.zeros( - 
(a.comm.size, a.shape[-2] // mB, a.shape[-1] // kB // a.comm.size, 2), - dtype=torch.int, - device=tdev, - ) - # units-> [process, dim0 block number, dim1 block number, start coord] **indices are local - - # below is to handle the edge case where there is only one element in one dimension of a - a_d0_1s_flag, a_d1_1s_flag = False, False - if any(lshape_map[:, 0, :][:, 0] == 1): - a_d0_1s_flag = True - if any(lshape_map[:, 0, :][:, 1] == 1): - a_d1_1s_flag = True - - index_map_comm.Wait() - for pr in range(a.comm.size): - start0 = index_map[pr, 0, 0, 0].item() - stop0 = index_map[pr, 0, 0, 1].item() - start1 = index_map[pr, 0, 1, 0].item() - stop1 = index_map[pr, 0, 1, 1].item() - - for dim0 in range( - (stop0 - start0) // mB // a.comm.size if a_d0_1s_flag else (stop0 - start0) // mB - ): - # loop over the number of blocks in the 0th dimension - for dim1 in range( - (stop1 - start1) // kB // a.comm.size if a_d1_1s_flag else (stop1 - start1) // kB - ): - # loop over the number of blocks in the 1st dimension - a_block_map[pr, dim0, dim1] = torch.tensor( - (dim0 * mB, dim1 * kB), dtype=torch.int, device=tdev - ) - rem_map_comm.Wait() - if b.split == 0: - # the blocks are shifted in the 2nd dimension of A for as many remainders - # there are between the blocks in the first dim of B - cnt = 0 - for r in rem_map[:, 1, 0]: - if r.item(): - cnt += 1 - a_block_map[:, :, cnt:, 1] += 1 - - if b.split == 0: - b_block_map = torch.zeros( - (b.comm.size, b.shape[-2] // kB // b.comm.size, b.shape[-1] // nB, 2), - dtype=torch.int, - device=tdev, - ) - elif b.split == 1: - b_block_map = torch.zeros( - (b.comm.size, b.shape[-2] // kB, b.shape[-1] // nB // b.comm.size, 2), - dtype=torch.int, - device=tdev, - ) - # units-> [process, dim0 block number, dim1 block number, start coord] **indices are local - - # below is to handle the edge case where there is only one element in one dimension of b - b_d0_1s_flag, b_d1_1s_flag = False, False - if any(lshape_map[:, 1, :][:, 0] == 1): - b_d0_1s_flag = True - if any(lshape_map[:, 1, :][:, 1] == 1): - b_d1_1s_flag = True - - for pr in range(b.comm.size): - start0 = index_map[pr, 1, 0, 0].item() - stop0 = index_map[pr, 1, 0, 1].item() - start1 = index_map[pr, 1, 1, 0].item() - stop1 = index_map[pr, 1, 1, 1].item() - - # loop over the number of blocks in the 0th dimension - for dim0 in range( - (stop0 - start0) // kB // b.comm.size if b_d0_1s_flag else (stop0 - start0) // kB - ): - # loop over the number of blocks in the 1st dimension - for dim1 in range( - (stop1 - start1) // nB // b.comm.size if b_d1_1s_flag else (stop1 - start1) // nB + # block sizes dont need to be the same. they just need the same inner dimension (kB) + kB = 0 # redundant? + rem_a, rem_b = 0, 0 + if a.split == ndim - 1 and b.split == ndim - 2: # split 10 + # if the split direction is the last dim in a and the first dim in b + # the max inner dim (kB) is the min value from the result of the integer division + # of the last dim of a/world size and the first dim of b/world size + kB = min( + [a.gshape[-1] // comm.size, b.gshape[-2] // comm.size] + ) # a.gshape[-1] == b.gshape[-2] + elif a.split == ndim - 2 and b.split == ndim - 1: # split 01 + kB = a.gshape[-1] + elif a.split == ndim - 1: # split 11 + kB = a.gshape[-1] // comm.size + elif b.split == ndim - 2: # split 00 + kB = b.gshape[-2] // comm.size + kB = min( + kB, a.gshape[-1] + ) # shouldnt this always be kB and be the same as for split 11? 
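+        # kB is the blocking size along the shared inner dimension (a.gshape[-1] == b.gshape[-2]);
+        # mB and nB below are the outer blocking sizes, taken as the smallest local extent of
+        # a's rows and b's columns across processes. The rem_* flags mark processes whose local
+        # chunk does not fill these blocks evenly; the leftover rows/columns are handled separately.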
+ + if a.lshape[-1] % kB != 0 or ( + kB == 1 and a.lshape[-1] != 1 + ): # does kb == 1 imply a.lshape[-1] > 1? + rem_a = 1 + if b.lshape[-2] % kB != 0 or (kB == 1 and b.lshape[-2] != 1): + rem_b = 1 + + # get the lshape map to determine what needs to be sent where as well as M and N + # lshape map dims -> {node, a=0 | b=1, lshape} + lshape_map = torch.zeros((comm.size, 2, ndim), dtype=int, device=tdev) + lshape_map[comm.rank, 0, :] = torch.tensor(a.lshape, device=tdev) + lshape_map[comm.rank, 1, :] = torch.tensor(b.lshape, device=tdev) + comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM) + + # find mB (first blocking dim for a) and nB (2nd blocking dim for b) + mB = lshape_map[:, 0, -2].min().item() # smallest number of local rows of a on a node + nB = lshape_map[:, 1, -1].min().item() # smallest number of local columns of b on a node + + # check for remaining dims in the outside dimensions + rem_a_out, rem_b_out = 0, 0 + if a.lshape[-2] % mB != 0 or (kB == 1 and a.lshape[-2] != 1): + rem_a_out = 1 + if b.lshape[-1] % nB != 0 or (kB == 1 and b.lshape[-1] != 1): + rem_b_out = 1 + + # get the flags from all processes + # rem_map dims guide -> {process number, a/b (0/1), dim0/dim1 (0/1), True/False (1/0) + # if there is a remainder in this dimension + rem_map = torch.zeros((comm.size, 2, 2)) + rem_map[comm.rank, 0, :] = torch.tensor((rem_a_out, rem_a), device=tdev) + rem_map[comm.rank, 1, :] = torch.tensor((rem_b, rem_b_out), device=tdev) + rem_map_comm = comm.Iallreduce(MPI.IN_PLACE, rem_map, MPI.SUM) + + # index_map dims guide -> {process number, a=0/b=1, relevant 1st index, 2nd index} + index_map = torch.zeros((comm.size, 2, 2, 2), dtype=int, device=tdev) + a_idx = comm.chunk(a.shape, a.split)[2] + index_map[comm.rank, 0, 0] = torch.tensor((a_idx[-2].start, a_idx[-2].stop), device=tdev) + index_map[comm.rank, 0, 1] = torch.tensor((a_idx[-1].start, a_idx[-1].stop), device=tdev) + b_idx = comm.chunk(b.shape, b.split)[2] + index_map[comm.rank, 1, 0] = torch.tensor((b_idx[-2].start, b_idx[-2].stop), device=tdev) + index_map[comm.rank, 1, 1] = torch.tensor((b_idx[-1].start, b_idx[-1].stop), device=tdev) + index_map_comm = comm.Iallreduce(MPI.IN_PLACE, index_map, MPI.SUM) + + # output: c = a @ b + # for the communication scheme, the output array needs to be created + c_shape = (*batch_shape, a.gshape[-2], b.gshape[-1]) + c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=dev, comm=comm) + + # get the index map for c + c_index_map = factories.zeros((c.comm.size, 2, 2), device=dev, comm=comm) + c_idx = comm.chunk(c.shape, c.split)[2] + c_index_map[comm.rank, 0, :] = (c_idx[-2].start, c_idx[-2].stop) + c_index_map[comm.rank, 1, :] = (c_idx[-1].start, c_idx[-1].stop) + c_index_map_comm = comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM) + + if a.split == ndim - 2: + a_block_map = torch.zeros( + (comm.size, a.shape[-2] // mB // comm.size, a.shape[-1] // kB, 2), + dtype=torch.int, + device=tdev, + ) + elif a.split == ndim - 1: # else should be equivalent at this point + a_block_map = torch.zeros( + (comm.size, a.shape[-2] // mB, a.shape[-1] // kB // comm.size, 2), + dtype=torch.int, + device=tdev, + ) + # units-> [process, dim0 block number, dim1 block number, start coord] **indices are local + + # below is to handle the edge case where there is only one element in one dimension of a + a_d0_1s_flag, a_d1_1s_flag = False, False + if any(lshape_map[:, 0, :][:, -2] == 1): + a_d0_1s_flag = True + if any(lshape_map[:, 0, :][:, -1] == 1): + a_d1_1s_flag = True + + index_map_comm.Wait() + for 
pr in range(comm.size): + start0 = index_map[pr, 0, 0, 0].item() + stop0 = index_map[pr, 0, 0, 1].item() + start1 = index_map[pr, 0, 1, 0].item() + stop1 = index_map[pr, 0, 1, 1].item() + + # maybe we could use torch.arange instead of this nested loop + for dim0 in range( + (stop0 - start0) // mB // comm.size if a_d0_1s_flag else (stop0 - start0) // mB ): - b_block_map[pr, dim0, dim1] = torch.tensor( - (dim0 * kB, dim1 * nB), dtype=torch.int, device=tdev - ) - - if a.split == 1: - cnt = 0 - # this loop will push the blocks in B to adjust for the remainders in A - for r in rem_map[:, 0, 1]: - if r.item(): - cnt += 1 - b_block_map[:, cnt:, :, 0] += 1 - - # work loop: loop over all processes (also will incorporate the remainder calculations) - c_wait.Wait() - - if split_0_flag: - # need to send b here and not a - # the rows on 'a' are complete, and the columns of 'b' are split - # locations of the remainders in b - b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False) - a_rem_locs0 = torch.nonzero(rem_map[:, 0, 0] == 1, as_tuple=False) - # remainders for a in the - a_node_rem_s0 = a.larray[:mB, kB : (kB + 1) * b_rem_locs0.numel() : kB + 1] - b_rem = torch.empty( - b_rem_locs0.numel(), b.lshape[-1], dtype=a.dtype.torch_type(), device=tdev - ) + # loop over the number of blocks in the 0th dimension + for dim1 in range( + (stop1 - start1) // kB // comm.size if a_d1_1s_flag else (stop1 - start1) // kB + ): + # loop over the number of blocks in the 1st dimension + a_block_map[pr, dim0, dim1] = torch.tensor( + (dim0 * mB, dim1 * kB), dtype=torch.int, device=tdev + ) + rem_map_comm.Wait() + + if b.split == ndim - 2: + # the blocks are shifted in the 2nd dimension of A for as many remainders + # there are between the blocks in the first dim of B + cnt = 0 + for r in rem_map[:, 1, 0]: + if r.item(): + cnt += 1 + # why increment by exactly 1? what can we assume about the lshapes on different nodes? + # can the sizes in the split dimension differ by more than 1? 
+ a_block_map[:, :, cnt:, 1] += 1 + + b_block_map = torch.zeros( + (comm.size, b.shape[-2] // kB // comm.size, b.shape[-1] // nB, 2), + dtype=torch.int, + device=tdev, + ) + else: # b split 1 + b_block_map = torch.zeros( + (comm.size, b.shape[-2] // kB, b.shape[-1] // nB // comm.size, 2), + dtype=torch.int, + device=tdev, + ) + # units-> [process, dim0 block number, dim1 block number, start coord] **indices are local - # this if/elif/else loop is for the handling of - if a.comm.rank in a_rem_locs0: - # if A is split in dim0 and the rank has a remainder in this direction - r = a.larray[-1] - r_loc = index_map[a.comm.rank, 0, 0, 1] - index_map[a.comm.rank, 0, 0, 0] - 1 - else: - r = None - r_loc = None + # below is to handle the edge case where there is only one element in one dimension of b + b_d0_1s_flag, b_d1_1s_flag = False, False + if any(lshape_map[:, 1, :][:, -2] == 1): + b_d0_1s_flag = True + if any(lshape_map[:, 1, :][:, -1] == 1): + b_d1_1s_flag = True - req = {} - b_lp_data = {} for pr in range(b.comm.size): - # ibcast data on node first - if b.comm.rank == pr: - b_lp_data[pr] = b.larray.clone() - else: - b_lp_data[pr] = torch.zeros( - (lshape_map[pr, 1, 0].item(), lshape_map[pr, 1, 1].item()), - dtype=b.dtype.torch_type(), - device=tdev, - ) + start0 = index_map[pr, 1, 0, 0].item() + stop0 = index_map[pr, 1, 0, 1].item() + start1 = index_map[pr, 1, 1, 0].item() + stop1 = index_map[pr, 1, 1, 1].item() - # sending a to all nodes for b to operate with - req[pr] = b.comm.Ibcast(b_lp_data[pr], root=pr) - - # receive the data from the last loop and do the calculation with that - if pr != 0: - req[pr - 1].Wait() - # after receiving the last loop's bcast - __mm_c_block_setter( - b_proc=pr - 1, - a_proc=a.comm.rank, - a_data=a.larray, - b_data=b_lp_data[pr - 1], - b_block_map=b_block_map, - a_block_map=a_block_map, - b_split=b.split, - a_split=a.split, - mB=mB, - kB=kB, - nB=nB, - c=c.larray, - ) + # loop over the number of blocks in the 0th dimension + for dim0 in range( + (stop0 - start0) // kB // b.comm.size if b_d0_1s_flag else (stop0 - start0) // kB + ): + # loop over the number of blocks in the 1st dimension + for dim1 in range( + (stop1 - start1) // nB // b.comm.size + if b_d1_1s_flag + else (stop1 - start1) // nB + ): + b_block_map[pr, dim0, dim1] = torch.tensor( + (dim0 * kB, dim1 * nB), dtype=torch.int, device=tdev + ) - # check if there is a remainder on b in the previous node - # this loop is intended to get the remainders of b since it is the one being passed - if pr - 1 in b_rem_locs0: - # takes care of the remainders in b as well as dim0 of a - b_rem[pr - 1] = b_lp_data[pr - 1][-1] - - # this loop is to take care of the remainders in dim0 of A - if a_rem_locs0.nelement() != 0 and r_loc is not None: - st = index_map[pr - 1, 1, 0, 0].item() - sp = index_map[pr - 1, 1, 0, 1].item() - c.larray[r_loc.item(), :] += r[st:sp] @ b_lp_data[pr - 1] - del b_lp_data[pr - 1] - - # need to wait if its the last loop, also need to collect the remainders - if pr == b.comm.size - 1: - req[pr].Wait() - __mm_c_block_setter( - b_proc=pr, - a_proc=a.comm.rank, - a_data=a.larray, - b_data=b_lp_data[pr], - b_block_map=b_block_map, - a_block_map=a_block_map, - b_split=b.split, - a_split=a.split, - mB=mB, - kB=kB, - nB=nB, - c=c.larray, - ) - # check if there is a remainder on b on the last node (there shouldnt be) - if pr in b_rem_locs0: - # this is to save the data from B required by the remainders from dim1 of A - b_rem[pr] = b_lp_data[pr][-1] - - # this loop is to take care of the remainders in 
the 0th dimension of A - if a_rem_locs0.nelement() != 0 and r_loc is not None: - st = index_map[pr, 1, 0, 0].item() - sp = index_map[pr, 1, 0, 1].item() - - if split_01_flag: - st1 = index_map[pr, 1, 1, 0].item() - sp1 = index_map[pr, 1, 1, 1].item() - c.larray[r_loc.item(), st1:sp1] += r[st:sp] @ b_lp_data[pr] - else: - c.larray[r_loc.item(), :] += r[st:sp] @ b_lp_data[pr] - - # set the final blocks on the last loop, then adjust for the - # the remainders which were collected in b_rem - if b_rem_locs0.numel(): - c.larray[: a_node_rem_s0.shape[0]] += a_node_rem_s0 @ b_rem - del b_lp_data[pr] - - if vector_flag: - c_loc = c.larray.squeeze() - if c_loc.nelement() == 1: - c_loc = torch.tensor(c_loc, device=tdev) - - c = factories.array(c_loc, is_split=0, device=a.device, comm=a.comm) - if gpu_int_flag: - c = og_type(c, device=a.device) - return c + if a.split == ndim - 1: + cnt = 0 + # this loop will push the blocks in B to adjust for the remainders in A + for r in rem_map[:, 0, 1]: + if r.item(): + cnt += 1 + b_block_map[:, cnt:, :, 0] += 1 + + # work loop: loop over all processes (also will incorporate the remainder calculations) + c_index_map_comm.Wait() + + # split la dims 00 + if a.split == ndim - 2 and b.split == ndim - 2: + # need to send b here and not a + # the rows on 'a' are complete, and the columns of 'b' are split + # locations of the remainders in b + b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False) + a_rem_locs0 = torch.nonzero(rem_map[:, 0, 0] == 1, as_tuple=False) + # remainders for a in the + a_node_rem_s0 = a.larray[..., :mB, kB : (kB + 1) * b_rem_locs0.numel() : kB + 1] + b_rem = torch.empty( + (*batch_shape, b_rem_locs0.numel(), b.lshape[-1]), + dtype=a.dtype.torch_type(), + device=tdev, + ) - elif split_1_flag: - # for this case, a is sent to b - # this is because 'b' has complete columns and the rows of 'a' are split - # locations of the remainders in b - b_rem_locs1 = torch.nonzero(rem_map[:, 1, 1] == 1, as_tuple=False) - a_rem_locs1 = torch.nonzero(rem_map[:, 0, 1] == 1, as_tuple=False) - b_node_rem_s1 = b.larray[kB : (kB + 1) * a_rem_locs1.numel() : kB + 1, :nB] - # b_node_rem_s1 -> remainders for a in the - - a_rem = torch.empty( - a.lshape[-2], a_rem_locs1.numel(), dtype=b.dtype.torch_type(), device=tdev - ) - # this if/elif/else loop is for the handling of - if b.comm.rank in b_rem_locs1: - # if b is split in dim1 and the rank has a remainder in this direction - r = b.larray[:, -1] - r_loc = index_map[a.comm.rank, 1, 1, 1] - index_map[a.comm.rank, 1, 1, 0] - 1 - else: - r = None - r_loc = None - req = {} - a_lp_data = {} - for pr in range(a.comm.size): - # ibcast data on node first - if a.comm.rank == pr: - a_lp_data[pr] = a.larray.clone() + # this if/elif/else loop is for the handling of + if comm.rank in a_rem_locs0: + # if A is split in dim0 and the rank has a remainder in this direction + r = a.larray[..., -1, :].unsqueeze(-2) + # can we not just set r_loc = -1 instead? 
+ r_loc = index_map[comm.rank, 0, 0, 1] - index_map[comm.rank, 0, 0, 0] - 1 else: - a_lp_data[pr] = torch.zeros( - (lshape_map[pr, 0, 0].item(), lshape_map[pr, 0, 1].item()), - dtype=a.dtype.torch_type(), - device=tdev, - ) - # sending a to all nodes for b to operate with - req[pr] = a.comm.Ibcast(a_lp_data[pr], root=pr) - # receive the data from the last loop and do the calculation with that - if pr != 0: - # after receiving the last loop's bcast - req[pr - 1].Wait() - __mm_c_block_setter( - a_proc=pr - 1, - b_proc=b.comm.rank, - a_data=a_lp_data[pr - 1], - b_data=b.larray, - b_block_map=b_block_map, - a_block_map=a_block_map, - b_split=b.split, - a_split=a.split, - mB=mB, - kB=kB, - nB=nB, - c=c.larray, - ) - # check if there is a remainder on b in the previous node - # this loop is intended to get the remainders of b since it is the one being passed - if pr - 1 in a_rem_locs1: - # takes care of the remainders in b as well as dim0 of a - a_rem[:, pr - 1] = a_lp_data[pr - 1][:, -1] - # this loop is to take care of the remainders in dim1 of B - if b_rem_locs1.nelement() != 0 and r_loc is not None: - st = index_map[pr - 1, 0, 1, 0].item() - sp = index_map[pr - 1, 0, 1, 1].item() - - c.larray[:, r_loc.item()] += (a_lp_data[pr - 1] @ r[st:sp, None]).flatten() - - del a_lp_data[pr - 1] - - # need to wait if its the last loop, also need to collect the remainders - if pr == b.comm.size - 1: - req[pr].Wait() - __mm_c_block_setter( - a_proc=pr, - b_proc=a.comm.rank, - a_data=a_lp_data[pr], - b_data=b.larray, - b_block_map=b_block_map, - a_block_map=a_block_map, - b_split=b.split, - a_split=a.split, - mB=mB, - kB=kB, - nB=nB, - c=c.larray, - ) - # check if there is a remainder on b on the last node (there shouldnt be) - if pr in a_rem_locs1: - # this is to save the data from B required by the remainders from dim1 of A - a_rem[:, pr] = a_lp_data[pr][:, -1] - # this loop is to take care of the remainders in the 0th dimension of A - if b_rem_locs1.nelement() != 0 and r_loc is not None: - st = index_map[pr, 0, 1, 0].item() - sp = index_map[pr, 0, 1, 1].item() - c.larray[:, r_loc.item()] += (a_lp_data[pr] @ r[st:sp, None]).flatten() - # set the final blocks on the last loop, then adjust for the the remainders which were collected in b_rem - if a_rem_locs1.numel(): - c.larray[:, : b_node_rem_s1.shape[1]] += a_rem @ b_node_rem_s1 - del a_lp_data[pr] - if vector_flag: - c = factories.array(c.larray.squeeze(), is_split=0, device=a.device, comm=a.comm) - if gpu_int_flag: - c = og_type(c, device=a.device) - return c + r = None + r_loc = None + + req = {} + b_lp_data = {} + for pr in range(comm.size): + # ibcast data on node first + if comm.rank == pr: + b_lp_data[pr] = b.larray.clone() + else: + b_lp_data[pr] = torch.zeros( + (*batch_shape, lshape_map[pr, 1, -2].item(), lshape_map[pr, 1, -1].item()), + dtype=b.dtype.torch_type(), + device=tdev, + ) + + # sending a to all nodes for b to operate with + req[pr] = comm.Ibcast(b_lp_data[pr], root=pr) + + # receive the data from the last loop and do the calculation with that + if pr != 0: + req[pr - 1].Wait() + # after receiving the last loop's bcast + __mm_c_block_setter( + b_proc=pr - 1, + a_proc=comm.rank, + a_data=a.larray, + b_data=b_lp_data[pr - 1], + b_block_map=b_block_map, + a_block_map=a_block_map, + b_split=0, + a_split=0, + mB=mB, + kB=kB, + nB=nB, + c=c.larray, + ) - elif split_01_flag: - # for this case there are no remainders which need to be taken care of - req = {} - b_lp_data = {} - for pr in range(a.comm.size): - # ibcast data on node first - if 
b.comm.rank == pr: - b_lp_data[pr] = b.larray.clone() + # check if there is a remainder on b in the previous node + # this loop is intended to get the remainders of b since it is the one being passed + if pr - 1 in b_rem_locs0: + # takes care of the remainders in b as well as dim0 of a + b_rem[..., pr - 1, :] = b_lp_data[pr - 1][..., -1, :] + + # this loop is to take care of the remainders in dim0 of a + if a_rem_locs0.nelement() != 0 and r_loc is not None: + st = index_map[pr - 1, 1, 0, 0].item() + sp = index_map[pr - 1, 1, 0, 1].item() + + c.larray[..., r_loc.item(), :] += ( + r[..., st:sp] @ b_lp_data[pr - 1] + ).squeeze(-2) + del b_lp_data[pr - 1] + + # need to wait if its the last loop, also need to collect the remainders + if pr == comm.size - 1: + req[pr].Wait() + __mm_c_block_setter( + b_proc=pr, + a_proc=comm.rank, + a_data=a.larray, + b_data=b_lp_data[pr], + b_block_map=b_block_map, + a_block_map=a_block_map, + b_split=0, + a_split=0, + mB=mB, + kB=kB, + nB=nB, + c=c.larray, + ) + # check if there is a remainder on b on the last node (there shouldnt be) + if pr in b_rem_locs0: + # this is to save the data from B required by the remainders from dim1 of A + b_rem[..., pr, :] = b_lp_data[pr][..., -1, :] + + # this loop is to take care of the remainders in the 0th dimension of A + if a_rem_locs0.nelement() != 0 and r_loc is not None: + st = index_map[pr, 1, 0, 0].item() + sp = index_map[pr, 1, 0, 1].item() # linear algebra dimension 0/1 + + # code not reachable? + # if split_01_flag: + if False: + st1 = index_map[pr, 1, 1, 0].item() + sp1 = index_map[pr, 1, 1, 1].item() + c.larray[..., r_loc.item(), st1:sp1] += r[..., st:sp] @ b_lp_data[pr] + else: + c.larray[..., r_loc.item(), :] += ( + r[..., st:sp] @ b_lp_data[pr] + ).squeeze(-2) + + # set the final blocks on the last loop, then adjust for the + # the remainders which were collected in b_rem + if b_rem_locs0.numel(): + c.larray[..., : a_node_rem_s0.shape[-2], :] += ( + a_node_rem_s0 @ b_rem + ) # shouldnt shape[0] always be mB? 
+ del b_lp_data[pr] + + # split la dims 01 + elif a.split == ndim - 2 and b.split == ndim - 1: + # for this case there are no remainders which need to be taken care of + req = {} + b_lp_data = {} + for pr in range(comm.size): + # ibcast data on node first + if comm.rank == pr: + b_lp_data[pr] = b.larray.clone() + else: + b_lp_data[pr] = torch.empty( + (*batch_shape, lshape_map[pr, 1, -2].item(), lshape_map[pr, 1, -1].item()), + dtype=b.dtype.torch_type(), + device=tdev, + ) + # sending a to all nodes for b to operate with + req[pr] = comm.Ibcast(b_lp_data[pr], root=pr) + + # receive the data from the last loop and do the calculation with that + if pr != 0: + req[pr - 1].Wait() + # after receiving the last loop's bcast + st0 = index_map[pr - 1, 0, 0, 0].item() + sp0 = index_map[pr - 1, 0, 0, 1].item() + 1 + st1 = index_map[pr - 1, 1, 1, 0].item() + sp1 = index_map[pr - 1, 1, 1, 1].item() + + c.larray[..., : sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr - 1] + + del b_lp_data[pr - 1] + if pr == comm.size - 1: + req[pr].Wait() + st0 = index_map[pr, 0, 0, 0].item() + sp0 = index_map[pr, 0, 0, 1].item() + 1 + st1 = index_map[pr, 1, 1, 0].item() + sp1 = index_map[pr, 1, 1, 1].item() + c.larray[..., : sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr] + del b_lp_data[pr] + + # split la dims 11 + elif a.split == ndim - 1 and b.split == ndim - 1: + # for this case, a is sent to b + # this is because 'b' has complete columns and the rows of 'a' are split + # locations of the remainders in b + b_rem_locs1 = torch.nonzero(rem_map[:, 1, 1] == 1, as_tuple=False) + a_rem_locs1 = torch.nonzero(rem_map[:, 0, 1] == 1, as_tuple=False) + b_node_rem_s1 = b.larray[..., kB : (kB + 1) * a_rem_locs1.numel() : kB + 1, :nB] + # b_node_rem_s1 -> remainders for a in the + + a_rem = torch.empty( + (*batch_shape, a.lshape[-2], a_rem_locs1.numel()), + dtype=b.dtype.torch_type(), + device=tdev, + ) + # this if/elif/else loop is for the handling of + if comm.rank in b_rem_locs1: + # if b is split in dim1 and the rank has a remainder in this direction + r = b.larray[..., -1].unsqueeze(-1) + r_loc = index_map[comm.rank, 1, 1, 1] - index_map[comm.rank, 1, 1, 0] - 1 else: - b_lp_data[pr] = torch.empty( - (lshape_map[pr, 1, 0].item(), lshape_map[pr, 1, 1].item()), - dtype=b.dtype.torch_type(), - device=tdev, + r = None + r_loc = None + req = {} + a_lp_data = {} + for pr in range(comm.size): + # ibcast data on node first + if a.comm.rank == pr: + a_lp_data[pr] = a.larray.clone() + else: + a_lp_data[pr] = torch.zeros( + (*batch_shape, lshape_map[pr, 0, -2].item(), lshape_map[pr, 0, -1].item()), + dtype=a.dtype.torch_type(), + device=tdev, + ) + # sending a to all nodes for b to operate with + req[pr] = comm.Ibcast(a_lp_data[pr], root=pr) + # receive the data from the last loop and do the calculation with that + if pr != 0: + # after receiving the last loop's bcast + req[pr - 1].Wait() + __mm_c_block_setter( + a_proc=pr - 1, + b_proc=comm.rank, + a_data=a_lp_data[pr - 1], + b_data=b.larray, + b_block_map=b_block_map, + a_block_map=a_block_map, + a_split=1, + b_split=1, + mB=mB, + kB=kB, + nB=nB, + c=c.larray, + ) + # check if there is a remainder on b in the previous node + # this loop is intended to get the remainders of b since it is the one being passed + if pr - 1 in a_rem_locs1: + # takes care of the remainders in b as well as dim0 of a + a_rem[..., pr - 1] = a_lp_data[pr - 1][..., -1] + # this loop is to take care of the remainders in dim1 of B + if b_rem_locs1.nelement() != 0 and r_loc is not None: + st = index_map[pr - 1, 0, 
1, 0].item() + sp = index_map[pr - 1, 0, 1, 1].item() + + c.larray[..., r_loc.item()] += ( + a_lp_data[pr - 1] @ r[..., st:sp, :] + ).squeeze(-1) + + del a_lp_data[pr - 1] + + # need to wait if its the last loop, also need to collect the remainders + if pr == b.comm.size - 1: + req[pr].Wait() + __mm_c_block_setter( + a_proc=pr, + b_proc=a.comm.rank, + a_data=a_lp_data[pr], + b_data=b.larray, + b_block_map=b_block_map, + a_block_map=a_block_map, + a_split=1, + b_split=1, + mB=mB, + kB=kB, + nB=nB, + c=c.larray, + ) + # check if there is a remainder on b on the last node (there shouldnt be) + if pr in a_rem_locs1: + # this is to save the data from B required by the remainders from dim1 of A + a_rem[..., pr] = a_lp_data[pr][..., -1] + # this loop is to take care of the remainders in the 0th dimension of A + if b_rem_locs1.nelement() != 0 and r_loc is not None: + st = index_map[pr, 0, 1, 0].item() + sp = index_map[pr, 0, 1, 1].item() + c.larray[..., r_loc.item()] += (a_lp_data[pr] @ r[..., st:sp, :]).squeeze( + -1 + ) + # set the final blocks on the last loop, then adjust for the the remainders which were collected in b_rem + if a_rem_locs1.numel(): + c.larray[..., : b_node_rem_s1.shape[-1]] += a_rem @ b_node_rem_s1 + del a_lp_data[pr] + + # split la dims 10 + elif a.split == ndim - 1 and b.split == ndim - 2: + # todo: this may create the full matrix on evey process, issue #360 + # for this case, only a sum is needed at the end + a_rem_locs1 = torch.nonzero(rem_map[:, 0, 1] == 1, as_tuple=False) + # locations of the remainders in b + b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False) + res = torch.zeros( + (*batch_shape, a.gshape[-2], b.gshape[-1]), dtype=c_type.torch_type(), device=tdev + ) + for i in range(a.lshape[-1] // kB): + res += ( + a.larray[..., :mB, i * kB : i * kB + kB] + @ b.larray[..., i * kB : i * kB + kB, :nB] ) - # sending a to all nodes for b to operate with - req[pr] = b.comm.Ibcast(b_lp_data[pr], root=pr) - - # receive the data from the last loop and do the calculation with that - if pr != 0: - req[pr - 1].Wait() - # after receiving the last loop's bcast - st0 = index_map[pr - 1, 0, 0, 0].item() - sp0 = index_map[pr - 1, 0, 0, 1].item() + 1 - st1 = index_map[pr - 1, 1, 1, 0].item() - sp1 = index_map[pr - 1, 1, 1, 1].item() - - c.larray[: sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr - 1] - - del b_lp_data[pr - 1] - if pr == b.comm.size - 1: - req[pr].Wait() - st0 = index_map[pr, 0, 0, 0].item() - sp0 = index_map[pr, 0, 0, 1].item() + 1 - st1 = index_map[pr, 1, 1, 0].item() - sp1 = index_map[pr, 1, 1, 1].item() - c.larray[: sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr] - del b_lp_data[pr] - if vector_flag: - c = factories.array(c.larray.squeeze(), is_split=0, device=a.device, comm=a.comm) - if gpu_int_flag: - c = og_type(c, device=a.device) - - return c + if a.comm.rank in a_rem_locs1 and b.comm.rank in b_rem_locs0 and kB > 1: + # these Nones are used to change the dims if the full process is not covered + res += a.larray[..., :, -1, None] @ b.larray[..., None, -1, :] + + comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM) + split = a.split if b.gshape[-1] > 1 else ndim - 2 + c = factories.array(res, split=split, device=dev, comm=comm) + + if vector_flag: # squeeze only in the la dimensions + # it could be sensible to resplit/rebalance in case a single node gets the whole vector + split = c.split + if split is not None and split > batch_dim: + split = batch_dim + c_loc = c.larray + if c_loc.numel() == 0: # empty tensor cannot be squeezed + c_loc = 
torch.zeros((*batch_shape, 0), dtype=c_type.torch_type(), device=tdev) + else: + c_loc = c_loc.squeeze(batch_dim) + if c_loc.ndim >= batch_dim + 2: + c_loc = c_loc.squeeze(batch_dim + 1) + c = factories.array(c_loc, is_split=split, device=dev, comm=comm) - elif split_10_flag: - # todo: this may create the full matrix on evey process, issue #360 - # for this case, only a sum is needed at the end - a_rem_locs1 = torch.nonzero(rem_map[:, 0, 1] == 1, as_tuple=False) - # locations of the remainders in b - b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False) - res = torch.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type.torch_type(), device=tdev) - for i in range(a.lshape[-1] // kB): - res += a.larray[:mB, i * kB : i * kB + kB] @ b.larray[i * kB : i * kB + kB, :nB] - if a.comm.rank in a_rem_locs1 and b.comm.rank in b_rem_locs0 and kB > 1: - # these Nones are used to change the dims if the full process is not covered - res += a.larray[:, -1, None] @ b.larray[None, -1, :] + if gpu_int_flag: + c = og_type(c, device=dev) - a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM) - split = a.split if b.gshape[1] > 1 else 0 - if vector_flag: - split = 0 - res = res.squeeze() - c = factories.array(res, split=split, device=a.device, comm=a.comm) - if gpu_int_flag: - c = og_type(c, device=a.device) - return c + return c def _matmul(self, other): @@ -2002,9 +2066,9 @@ def __mm_c_block_setter( a_block_map : torch.Tensor block map for A b_split : int - split of B + split of B (0 or 1) a_split : int - split of A + split of A (0 or 1) mB : int block size of m kB : int @@ -2020,7 +2084,7 @@ def __mm_c_block_setter( shp_a = a_block_map.shape offset_b = a_proc * shp_a[2] if a_proc != 0 else 0 # offsets are the number of blocks in the multiplication direction on previous nodes - # print(a_block_map[a_proc].shape[0]) + for bl_1_a in ( torch.arange(offset_a, offset_a + shp_b[1], dtype=torch.long, device=c.device) if b_split == 0 @@ -2043,15 +2107,16 @@ def __mm_c_block_setter( # this offset is the same as before but for b a_start1 = int(a_block_map[a_proc, bl_0_a, bl_1_a, 1].item()) a_start0 = int(a_block_map[a_proc, bl_0_a, bl_1_a, 0].item()) - a_block = a_data[a_start0 : a_start0 + mB, a_start1 : a_start1 + kB] + a_block = a_data[..., a_start0 : a_start0 + mB, a_start1 : a_start1 + kB] b_start0 = int(b_block_map[b_proc, bl_0_b, bl_1_b, 0].item()) b_start1 = int(b_block_map[b_proc, bl_0_b, bl_1_b, 1].item()) - b_block = b_data[b_start0 : b_start0 + kB, b_start1 : b_start1 + nB] + b_block = b_data[..., b_start0 : b_start0 + kB, b_start1 : b_start1 + nB] c_start0 = a_start0 c_start1 = b_start1 - c[c_start0 : c_start0 + mB, c_start1 : c_start1 + nB] += a_block @ b_block + + c[..., c_start0 : c_start0 + mB, c_start1 : c_start1 + nB] += a_block @ b_block def transpose(a: DNDarray, axes: Optional[List[int]] = None) -> DNDarray: diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index e4304b1ad4..39fc2583ba 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -563,6 +563,7 @@ def test_matmul(self): ret00 = ht.matmul(a, b) ret_comp = ht.array(a_torch @ b_torch, split=None) + self.assertTrue(ht.equal(ret00, ret_comp)) self.assertIsInstance(ret00, ht.DNDarray) self.assertEqual(ret00.shape, (k,)) @@ -805,13 +806,118 @@ def test_matmul(self): self.assertEqual(ret00.dtype, ht.int64) self.assertEqual(ret00.split, 0) + """ with self.assertRaises(NotImplementedError): a = ht.zeros((3, 3, 3), split=2) b = a.copy() a @ b + """ + with 
self.assertRaises(TypeError): "T" @ ht.zeros((3, 3, 3)) + # batched, dimension errors + # different number of batch dimensions + with self.assertRaises(ValueError): + a = ht.zeros((3, 3, 3)) + b = ht.zeros((3,)) + ht.matmul(a, b) + # different batch dimension shape + with self.assertRaises(ValueError): + a = ht.zeros((3, 3, 3), split=0) + b = ht.zeros((4, 3, 3), split=0) + ht.matmul(a, b) + # not implemented split + """ + todo + with self.assertRaises(NotImplementedError): + a = ht.zeros((3, 3, 3)) + b = ht.zeros((3, 3, 3)) + ht.matmul(a, b) + """ + + # batched, split batch + n = 11 # number of batches + k = 100 # data dimension size + s1 = ht.arange(n, dtype=ht.int64).reshape((n, 1, 1)) + zeros = ht.zeros((n, 1, k - 1), dtype=ht.int64) + a = ht.concatenate((s1, zeros), 2) + a.resplit_(0) + z1 = ht.ones((n, 1, 1), dtype=ht.int64) + zeros = ht.zeros((n, k - 1, 1), dtype=ht.int64) + b = ht.concatenate((z1, zeros), 1) + b.resplit_(0) + ret_batched = ht.matmul(a, b) + + self.assertTrue(ht.equal(ret_batched, s1)) + self.assertIsInstance(ret_batched, ht.DNDarray) + self.assertEqual( + ret_batched.shape, + ( + n, + 1, + 1, + ), + ) + self.assertEqual(ret_batched.dtype, ht.int64) + self.assertEqual(ret_batched.split, 0) + + # batched + n = 11 # number of batches + k = 100 # data dimension size + m = 100 + + torch.manual_seed(42) + + # integer + at = torch.randint(0, 100, (n, m, k)) + bt = torch.randint(0, 100, (n, k, m)) + ct = at @ bt + + a = ht.factories.asarray(at, copy=True) + b = ht.factories.asarray(bt, copy=True) + c = ht.factories.asarray(ct, copy=True) + + la_splits = (None, 0, 1) + # test all possible la split combinations + for s0 in la_splits: + if s0 is not None: + s0 -= 2 + for s1 in la_splits: + if s1 is not None: + s1 -= 2 + a.resplit_(s0) + b.resplit_(s1) + + ret_batched = ht.matmul(a, b) + + self.assertTrue(ht.equal(ret_batched, c)) + + # float + at = torch.randn((n, m, k)) + bt = torch.randn((n, k, m)) + ct = at @ bt + + a = ht.factories.asarray(at, copy=True) + b = ht.factories.asarray(bt, copy=True) + c = ht.factories.asarray(ct, copy=True) + + for s0 in la_splits: + if s0 is not None: + s0 -= 2 + for s1 in la_splits: + if s1 is not None: + s1 -= 2 + a.resplit_(s0) + b.resplit_(s1) + + ret_batched = ht.matmul(a, b) + # print(f"{s0}{s1}: {ht.max(ht.abs(ret_batched - c)).item()}") + max_diff = ht.max(ht.abs(ret_batched - c)).item() + + # self.assertTrue(ht.allclose(ret_batched, c, 1e-2)) + self.assertTrue(max_diff < 1e-4) + def test_matrix_norm(self): a = ht.arange(9, dtype=ht.float) - 4 b = a.reshape((3, 3)) diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index 1714421efd..6de9e091d8 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -8,6 +8,8 @@ class TestQR(TestCase): def test_qr_split1orNone(self): + ht.random.seed(1234) + for split in [1, None]: for mode in ["reduced", "r"]: # note that split = 1 can be handeled for arbitrary shapes @@ -22,6 +24,13 @@ def test_qr_split1orNone(self): qr = ht.linalg.qr(mat, mode=mode) if mode == "reduced": + allclose = ht.allclose(qr.Q @ qr.R, mat, atol=dtypetol, rtol=dtypetol) + if not allclose: + diff = qr.Q @ qr.R - mat + max_diff = ht.max(diff) + print(f"diff: {diff}") + print(f"max_diff: {max_diff}m") + self.assertTrue( ht.allclose(qr.Q @ qr.R, mat, atol=dtypetol, rtol=dtypetol) )
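
For reference, a minimal usage sketch of the batched behaviour added by this patch. The shapes, variable names, and the suggested launch command are illustrative only; the functions used (ht.ones, ht.matmul, ht.equal) all appear in the patch above, and the expected split of the result follows from the batch-split branch of matmul.

    # Sketch only: a batch of four 3x5 @ 5x2 products with the batch axis distributed.
    # Launch with one or more MPI processes, e.g. `mpirun -n 2 python demo.py`.
    import heat as ht

    a = ht.ones((4, 3, 5), split=0)  # four 3x5 matrices, split along the batch axis
    b = ht.ones((4, 5, 2), split=0)  # batch dimensions (and batch split) must match a
    c = ht.matmul(a, b)              # per-process torch.matmul, no redistribution needed

    assert c.shape == (4, 3, 2)
    assert c.split == 0                          # result keeps the batch split of the inputs
    assert ht.equal(c, 5 * ht.ones((4, 3, 2)))   # each entry is a dot product of five ones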