Batched matrix multiplication. #1261

Merged
Changes from 1 commit (of 59 commits)
67943f9
first implementation of the minimal solution
FOsterfeld Nov 7, 2023
1c60823
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Nov 8, 2023
0ddba6c
access b.gshape[-2] only if input is not batched
FOsterfeld Nov 8, 2023
d60c0ca
fixed batched condition
FOsterfeld Nov 21, 2023
a60d3ac
throw a NotImplementedError for wrong split dimension on batched matmul
FOsterfeld Nov 21, 2023
e16366b
fixed dimension condition
FOsterfeld Nov 21, 2023
7644dd4
added test for batched matmul with split dimension being a batch dime…
FOsterfeld Nov 21, 2023
5d34282
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Nov 21, 2023
095ccc5
fixed condition for different batch dimensions
FOsterfeld Nov 21, 2023
e55f8b8
added some tests for correctly thrown errors
FOsterfeld Nov 21, 2023
a9ae2bf
fixed test for batched matmul on gpu
FOsterfeld Nov 21, 2023
06913c7
test for batched matmul on gpu
FOsterfeld Nov 21, 2023
ba60c82
remove unnecessary test with device=gpu
FOsterfeld Nov 22, 2023
8d95ec1
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Nov 23, 2023
1c2a939
batched matmul with split==None for both matrices
FOsterfeld Nov 28, 2023
a79e42d
implemented batched matmul for case split 00
FOsterfeld Dec 14, 2023
980d0ec
implemented batched matmul for case split 01
FOsterfeld Dec 27, 2023
a44c6b2
implemented batched matmul for case split 11
FOsterfeld Dec 27, 2023
b506a66
cleaned up code to return the result
FOsterfeld Dec 27, 2023
f76e973
added tests for the batched matmul
FOsterfeld Dec 27, 2023
9733a28
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jan 2, 2024
0008a3f
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jan 8, 2024
18cdcf1
added batched matmul tests for float values
FOsterfeld Jan 9, 2024
4e49aa5
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jan 22, 2024
bb0856b
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 1, 2024
0d37ff4
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 5, 2024
8531106
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 6, 2024
c911e45
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 14, 2024
5a2ad15
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 20, 2024
e804c2c
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Feb 29, 2024
e5ff10b
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Mar 7, 2024
da9b0e3
improved exception throwing: error message when only one matrix has s…
FOsterfeld Mar 19, 2024
f3e0ced
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jun 3, 2024
2719f4d
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Jun 20, 2024
0f4c677
warn against the inefficient split cases in the matmul docstring
FOsterfeld Jul 4, 2024
96121ee
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jul 5, 2024
5e5eea3
Update basics.py
mrfh92 Jul 5, 2024
5933e48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
0a9795f
Update basics.py
mrfh92 Jul 5, 2024
a2f1cc5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
816cd85
fixed style complaints
Jul 5, 2024
40aa455
Apply suggestions from code review
FOsterfeld Jul 17, 2024
ea18fa1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2024
2222bb7
fixed documentation
FOsterfeld Jul 17, 2024
35ff132
Merge branch 'features/1104-Implement_consistent_linear_algebra_for_a…
FOsterfeld Jul 17, 2024
fd44be9
updated matmul tests for new batch behavior
FOsterfeld Aug 16, 2024
de73236
restructured code to remove code duplication of batched and unbatched…
FOsterfeld Aug 16, 2024
98a4134
generalized the split case None-None to batched matrices
FOsterfeld Aug 17, 2024
9e7d0f0
simplified the cases where not both matrices are split in la dimensions
FOsterfeld Aug 17, 2024
c95be79
generalized the None splits for batched matrices
FOsterfeld Aug 17, 2024
8dffcf5
removed unnecessary import
FOsterfeld Aug 17, 2024
41e203c
updated docstring
FOsterfeld Aug 17, 2024
3886942
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Aug 17, 2024
9f9462c
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Aug 20, 2024
3f15690
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Aug 28, 2024
ca066b2
initialize random generator
FOsterfeld Aug 30, 2024
398b27e
refactored code for None splits
FOsterfeld Aug 30, 2024
6f92537
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Sep 2, 2024
fc97280
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mtar Sep 6, 2024
implemented batched matmul for case split 00
FOsterfeld committed Dec 14, 2023
commit a79e42d8ab2fcb1ae4580aac247b9f0814ba69de
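
For orientation, here is a minimal usage sketch of the case this commit targets: a batched matmul where both operands carry one batch dimension and are split along their second-to-last (row) dimension, i.e. the "split 00" case. The shapes, the split value, and the random-data setup are illustrative assumptions, not taken from the PR; only ht.matmul itself is the API being changed.

import heat as ht

# Illustrative sketch (assumed shapes): a batch of 5 matrix products,
# both operands split along their row dimension, the "split 00" case.
# Run under MPI, e.g. `mpirun -n 2 python batched_matmul_sketch.py`.
a = ht.random.rand(5, 40, 30, split=1)   # split=1 is the second-to-last dim of these 3-D arrays
b = ht.random.rand(5, 30, 20, split=1)

c = ht.matmul(a, b)                      # expected shape (5, 40, 20), batched over the first axis
print(c.shape, c.split)
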
334 changes: 325 additions & 9 deletions heat/core/linalg/basics.py
@@ -543,12 +543,327 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
ret = og_type(ret, device=a.device)
return ret

if a.split is None or b.split is None: # only one matrix has split None
raise NotImplementedError

comm = a.comm
ndim = len(a.gshape)
dev = a.device
tdev = dev.torch_device
batch_shape = a.gshape[:batch_dim]

# block sizes don't need to be the same; they just need the same inner dimension (kB)
kB = 0 # redundant?
rem_a, rem_b = 0, 0
if a.split == ndim - 1 and b.split == ndim - 2: # split 10
# if the split direction is the last dim in a and the first dim in b
# the max inner dim (kB) is the min value from the result of the integer division
# of the last dim of a/world size and the first dim of b/world size
kB = min(
[a.gshape[-1] // comm.size, b.gshape[-2] // comm.size]
) # a.gshape[-1] == b.gshape[-2]
elif a.split == ndim - 2 and b.split == ndim - 1: # split 01
kB = a.gshape[-1]
elif a.split == ndim - 1: # split 11
kB = a.gshape[-1] // comm.size
elif b.split == ndim - 2: # split 00
kB = b.gshape[-2] // comm.size
kB = min(
kB, a.gshape[-1]
) # shouldn't this always be kB and be the same as for split 11?

if a.lshape[-1] % kB != 0 or (
kB == 1 and a.lshape[-1] != 1
): # does kB == 1 imply a.lshape[-1] > 1?
rem_a = 1
if b.lshape[-2] % kB != 0 or (kB == 1 and b.lshape[-2] != 1):
rem_b = 1

# get the lshape map to determine what needs to be sent where as well as M and N
# lshape map dims -> {node, a=0 | b=1, lshape}
lshape_map = torch.zeros((comm.size, 2, ndim), dtype=int, device=tdev)
lshape_map[comm.rank, 0, :] = torch.tensor(a.lshape, device=tdev)
lshape_map[comm.rank, 1, :] = torch.tensor(b.lshape, device=tdev)
comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)

# find mB (first blocking dim for a) and nB (2nd blocking dim for b)
mB = lshape_map[:, 0, -2].min().item() # smallest number of local rows of a on a node
nB = lshape_map[:, 1, -1].min().item() # smallest number of local columns of b on a node

# check for remaining dims in the outside dimensions
rem_a_out, rem_b_out = 0, 0
if a.lshape[-2] % mB != 0 or (kB == 1 and a.lshape[-2] != 1):
rem_a_out = 1
if b.lshape[-1] % nB != 0 or (kB == 1 and b.lshape[-1] != 1):
rem_b_out = 1

# get the flags from all processes
# rem_map dims guide -> {process number, a/b (0/1), dim0/dim1 (0/1), 1/0}
# where 1 means there is a remainder in this dimension
rem_map = torch.zeros((comm.size, 2, 2))
rem_map[comm.rank, 0, :] = torch.tensor((rem_a_out, rem_a), device=tdev)
rem_map[comm.rank, 1, :] = torch.tensor((rem_b, rem_b_out), device=tdev)
rem_map_comm = comm.Iallreduce(MPI.IN_PLACE, rem_map, MPI.SUM)

# index_map dims guide -> {process number, a=0/b=1, relevant 1st index, 2nd index}
index_map = torch.zeros((comm.size, 2, 2, 2), dtype=int, device=tdev)
a_idx = comm.chunk(a.shape, a.split)[2]
index_map[comm.rank, 0, 0] = torch.tensor((a_idx[-2].start, a_idx[-2].stop), device=tdev)
index_map[comm.rank, 0, 1] = torch.tensor((a_idx[-1].start, a_idx[-1].stop), device=tdev)
b_idx = comm.chunk(b.shape, b.split)[2]
index_map[comm.rank, 1, 0] = torch.tensor((b_idx[-2].start, b_idx[-2].stop), device=tdev)
index_map[comm.rank, 1, 1] = torch.tensor((b_idx[-1].start, b_idx[-1].stop), device=tdev)
index_map_comm = comm.Iallreduce(MPI.IN_PLACE, index_map, MPI.SUM)

# output: c = a @ b
# for the communication scheme, the output array needs to be created
c_shape = (*batch_shape, a.gshape[-2], b.gshape[-1])
c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=dev, comm=comm)

# get the index map for c
c_index_map = factories.zeros((c.comm.size, 2, 2), device=dev, comm=comm)
c_idx = comm.chunk(c.shape, c.split)[2]
c_index_map[comm.rank, 0, :] = (c_idx[-2].start, c_idx[-2].stop)
c_index_map[comm.rank, 1, :] = (c_idx[-1].start, c_idx[-1].stop)
c_index_map_comm = comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM)

if a.split == ndim - 2:
a_block_map = torch.zeros(
(comm.size, a.shape[-2] // mB // comm.size, a.shape[-1] // kB, 2),
dtype=torch.int,
device=tdev,
)
elif a.split == ndim - 1: # else should be equivalent at this point
a_block_map = torch.zeros(
(comm.size, a.shape[-2] // mB, a.shape[-1] // kB // comm.size, 2),
dtype=torch.int,
device=tdev,
)
# units-> [process, dim0 block number, dim1 block number, start coord] **indices are local

# below is to handle the edge case where there is only one element in one dimension of a
a_d0_1s_flag, a_d1_1s_flag = False, False
if any(lshape_map[:, 0, :][:, -2] == 1):
a_d0_1s_flag = True
if any(lshape_map[:, 0, :][:, -1] == 1):
a_d1_1s_flag = True

index_map_comm.Wait()
for pr in range(comm.size):
start0 = index_map[pr, 0, 0, 0].item()
stop0 = index_map[pr, 0, 0, 1].item()
start1 = index_map[pr, 0, 1, 0].item()
stop1 = index_map[pr, 0, 1, 1].item()

# maybe we could use torch.arange instead of this nested loop
for dim0 in range(
(stop0 - start0) // mB // comm.size if a_d0_1s_flag else (stop0 - start0) // mB
):
# loop over the number of blocks in the 0th dimension
for dim1 in range(
(stop1 - start1) // kB // comm.size if a_d1_1s_flag else (stop1 - start1) // kB
):
# loop over the number of blocks in the 1st dimension
a_block_map[pr, dim0, dim1] = torch.tensor(
(dim0 * mB, dim1 * kB), dtype=torch.int, device=tdev
)
rem_map_comm.Wait()

if b.split == ndim - 2:
# the blocks are shifted in the 2nd dimension of A for as many remainders
# there are between the blocks in the first dim of B
cnt = 0
for r in rem_map[:, 1, 0]:
if r.item():
cnt += 1
# why increment by exactly 1? what can we assume about the lshapes on different nodes?
# can the sizes in the split dimension differ by more than 1?
a_block_map[:, :, cnt:, 1] += 1

b_block_map = torch.zeros(
(comm.size, b.shape[-2] // kB // comm.size, b.shape[-1] // nB, 2),
dtype=torch.int,
device=tdev,
)
else: # b split 1
b_block_map = torch.zeros(
(comm.size, b.shape[-2] // kB, b.shape[-1] // nB // comm.size, 2),
dtype=torch.int,
device=tdev,
)
# units-> [process, dim0 block number, dim1 block number, start coord] **indices are local

# below is to handle the edge case where there is only one element in one dimension of b
b_d0_1s_flag, b_d1_1s_flag = False, False
if any(lshape_map[:, 1, :][:, -2] == 1):
b_d0_1s_flag = True
if any(lshape_map[:, 1, :][:, -1] == 1):
b_d1_1s_flag = True

for pr in range(b.comm.size):
start0 = index_map[pr, 1, 0, 0].item()
stop0 = index_map[pr, 1, 0, 1].item()
start1 = index_map[pr, 1, 1, 0].item()
stop1 = index_map[pr, 1, 1, 1].item()

# loop over the number of blocks in the 0th dimension
for dim0 in range(
(stop0 - start0) // kB // b.comm.size if b_d0_1s_flag else (stop0 - start0) // kB
):
# loop over the number of blocks in the 1st dimension
for dim1 in range(
(stop1 - start1) // nB // b.comm.size
if b_d1_1s_flag
else (stop1 - start1) // nB
):
b_block_map[pr, dim0, dim1] = torch.tensor(
(dim0 * kB, dim1 * nB), dtype=torch.int, device=tdev
)

if a.split == ndim - 1:
cnt = 0
# this loop will push the blocks in B to adjust for the remainders in A
for r in rem_map[:, 0, 1]:
if r.item():
cnt += 1
b_block_map[:, cnt:, :, 0] += 1

# work loop: loop over all processes (also will incorporate the remainder calculations)
c_index_map_comm.Wait()

# split la dims 00
if a.split == ndim - 2 and b.split == ndim - 2:
# need to send b here and not a
# the rows on 'a' are complete, and the columns of 'b' are split
# locations of the remainders in b
b_rem_locs0 = torch.nonzero(rem_map[:, 1, 0] == 1, as_tuple=False)
a_rem_locs0 = torch.nonzero(rem_map[:, 0, 0] == 1, as_tuple=False)
# columns of a that line up with the remainder rows of b (collected in b_rem below)
a_node_rem_s0 = a.larray[..., :mB, kB : (kB + 1) * b_rem_locs0.numel() : kB + 1]
b_rem = torch.empty(
(*batch_shape, b_rem_locs0.numel(), b.lshape[-1]),
dtype=a.dtype.torch_type(),
device=tdev,
)

# this if/else block handles the remainder row of a on ranks that have one
if comm.rank in a_rem_locs0:
# if A is split in dim0 and the rank has a remainder in this direction
r = a.larray[..., -1, :].unsqueeze(-2)
# can we not just set r_loc = -1 instead?
r_loc = index_map[comm.rank, 0, 0, 1] - index_map[comm.rank, 0, 0, 0] - 1
else:
r = None
r_loc = None

req = {}
b_lp_data = {}
for pr in range(comm.size):
# ibcast data on node first
if comm.rank == pr:
b_lp_data[pr] = b.larray.clone()
else:
b_lp_data[pr] = torch.zeros(
(*batch_shape, lshape_map[pr, 1, -2].item(), lshape_map[pr, 1, -1].item()),
dtype=b.dtype.torch_type(),
device=tdev,
)

# broadcast the local slab of b so every rank can multiply its local a against it
req[pr] = comm.Ibcast(b_lp_data[pr], root=pr)

# receive the data from the last loop and do the calculation with that
if pr != 0:
req[pr - 1].Wait()
# after receiving the last loop's bcast
__mm_c_block_setter(
b_proc=pr - 1,
a_proc=comm.rank,
a_data=a.larray,
b_data=b_lp_data[pr - 1],
b_block_map=b_block_map,
a_block_map=a_block_map,
b_split=0,
a_split=0,
mB=mB,
kB=kB,
nB=nB,
c=c.larray,
)

# check if there is a remainder on b in the previous node
# this loop is intended to get the remainders of b since it is the one being passed
if pr - 1 in b_rem_locs0:
# takes care of the remainders in b as well as dim0 of a
b_rem[..., pr - 1, :] = b_lp_data[pr - 1][..., -1, :]

# this loop is to take care of the remainders in dim0 of a
if a_rem_locs0.nelement() != 0 and r_loc is not None:
st = index_map[pr - 1, 1, 0, 0].item()
sp = index_map[pr - 1, 1, 0, 1].item()

c.larray[..., r_loc.item(), :] += (
r[..., st:sp] @ b_lp_data[pr - 1]
).squeeze(-2)
del b_lp_data[pr - 1]

# need to wait if it's the last loop; also need to collect the remainders
if pr == comm.size - 1:
req[pr].Wait()
__mm_c_block_setter(
b_proc=pr,
a_proc=comm.rank,
a_data=a.larray,
b_data=b_lp_data[pr],
b_block_map=b_block_map,
a_block_map=a_block_map,
b_split=0,
a_split=0,
mB=mB,
kB=kB,
nB=nB,
c=c.larray,
)
# check if there is a remainder on b on the last node (there shouldn't be)
if pr in b_rem_locs0:
# this is to save the data from B required by the remainders from dim1 of A
b_rem[..., pr, :] = b_lp_data[pr][..., -1, :]

# this loop is to take care of the remainders in the 0th dimension of A
if a_rem_locs0.nelement() != 0 and r_loc is not None:
st = index_map[pr, 1, 0, 0].item()
sp = index_map[pr, 1, 0, 1].item() # linear algebra dimension 0/1

# code not reachable?
# if split_01_flag:
if False:
st1 = index_map[pr, 1, 1, 0].item()
sp1 = index_map[pr, 1, 1, 1].item()
c.larray[..., r_loc.item(), st1:sp1] += r[..., st:sp] @ b_lp_data[pr]
else:
c.larray[..., r_loc.item(), :] += (
r[..., st:sp] @ b_lp_data[pr]
).squeeze(-2)

# set the final blocks on the last loop, then adjust for the
# the remainders which were collected in b_rem
if b_rem_locs0.numel():
c.larray[..., : a_node_rem_s0.shape[-2], :] += (
a_node_rem_s0 @ b_rem
) # shouldn't shape[-2] always be mB?
del b_lp_data[pr]

"""
if vector_flag:
c_loc = c.larray.squeeze()
if c_loc.nelement() == 1:
c_loc = torch.tensor(c_loc, device=tdev)

c = factories.array(c_loc, is_split=0, device=a.device, comm=a.comm)
"""
if gpu_int_flag:
c = og_type(c, device=a.device)
return c

if a.split is None and b.split is None: # matmul from torch
if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
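
Before the next hunk, a note on the split-00 branch added above: each rank broadcasts (Ibcast) its local row-slab of b, and every rank multiplies the matching column range of its local rows of a against that slab, accumulating into its local slab of c. The following is a serial toy model of that accumulation in plain PyTorch, with no MPI and no remainder handling; all names and sizes are illustrative, not Heat internals.

import torch

size = 4                                  # pretend world size
B, M, K, N = 2, 8, 8, 6                   # batch and matrix dimensions
a = torch.randn(B, M, K)
b = torch.randn(B, K, N)

a_slabs = list(a.chunk(size, dim=-2))     # each "rank's" local rows of a
b_slabs = list(b.chunk(size, dim=-2))     # each "rank's" local rows of b
k_starts = torch.tensor([0] + [s.shape[-2] for s in b_slabs[:-1]]).cumsum(0)

c_slabs = [torch.zeros(B, s.shape[-2], N) for s in a_slabs]
for pr in range(size):                    # stands in for the Ibcast loop over ranks
    st = int(k_starts[pr])
    sp = st + b_slabs[pr].shape[-2]
    for rank in range(size):              # what every rank does with the received slab
        c_slabs[rank] += a_slabs[rank][..., :, st:sp] @ b_slabs[pr]

assert torch.allclose(torch.cat(c_slabs, dim=-2), a @ b, atol=1e-5)
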
@@ -2035,9 +2350,9 @@ def __mm_c_block_setter(
a_block_map : torch.Tensor
block map for A
b_split : int
split of B (0 or 1)
a_split : int
split of A (0 or 1)
mB : int
block size of m
kB : int
@@ -2053,7 +2368,7 @@ def __mm_c_block_setter(
shp_a = a_block_map.shape
offset_b = a_proc * shp_a[2] if a_proc != 0 else 0
# offsets are the number of blocks in the multiplication direction on previous nodes

for bl_1_a in (
torch.arange(offset_a, offset_a + shp_b[1], dtype=torch.long, device=c.device)
if b_split == 0
@@ -2076,15 +2391,16 @@ def __mm_c_block_setter(
# this offset is the same as before but for b
a_start1 = int(a_block_map[a_proc, bl_0_a, bl_1_a, 1].item())
a_start0 = int(a_block_map[a_proc, bl_0_a, bl_1_a, 0].item())
a_block = a_data[..., a_start0 : a_start0 + mB, a_start1 : a_start1 + kB]

b_start0 = int(b_block_map[b_proc, bl_0_b, bl_1_b, 0].item())
b_start1 = int(b_block_map[b_proc, bl_0_b, bl_1_b, 1].item())
b_block = b_data[..., b_start0 : b_start0 + kB, b_start1 : b_start1 + nB]

c_start0 = a_start0
c_start1 = b_start1
c[..., c_start0 : c_start0 + mB, c_start1 : c_start1 + nB] += a_block @ b_block


def transpose(a: DNDarray, axes: Optional[List[int]] = None) -> DNDarray:
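
A closing note on __mm_c_block_setter: the block maps hold the local start coordinates of every (mB x kB) tile of a and (kB x nB) tile of b, and each tile product is accumulated into the matching (mB x nB) tile of c; the batched version above only prepends ellipsis indexing for the batch dimensions. Below is a self-contained toy version of the same tiling idea (single process, no batch dimension, sizes made up for illustration), not the Heat routine itself.

import torch

mB, kB, nB = 2, 3, 2
a_data = torch.randn(4, 6)          # local a: a 2 x 2 grid of (mB x kB) tiles
b_data = torch.randn(6, 4)          # local b: a 2 x 2 grid of (kB x nB) tiles
c = torch.zeros(4, 4)

# block maps: local start coordinates of each tile, as in a_block_map / b_block_map
a_block_map = torch.tensor([[[r * mB, col * kB] for col in range(2)] for r in range(2)])
b_block_map = torch.tensor([[[r * kB, col * nB] for col in range(2)] for r in range(2)])

for i in range(2):                  # tile row of a / c
    for k in range(2):              # shared (inner) tile index
        for j in range(2):          # tile column of b / c
            a0, a1 = a_block_map[i, k].tolist()
            b0, b1 = b_block_map[k, j].tolist()
            a_block = a_data[a0 : a0 + mB, a1 : a1 + kB]
            b_block = b_data[b0 : b0 + kB, b1 : b1 + nB]
            c[a0 : a0 + mB, b1 : b1 + nB] += a_block @ b_block

assert torch.allclose(c, a_data @ b_data, atol=1e-5)
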