Batched matrix multiplication. #1261

Merged
Changes from 1 commit
59 commits
67943f9
first implementation of the minimal solution
FOsterfeld Nov 7, 2023
1c60823
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Nov 8, 2023
0ddba6c
access b.gshape[-2] only if input is not batched
FOsterfeld Nov 8, 2023
d60c0ca
fixed batched condition
FOsterfeld Nov 21, 2023
a60d3ac
throw a NotImplementedError for wrong split dimension on batched matmul
FOsterfeld Nov 21, 2023
e16366b
fixed dimension condition
FOsterfeld Nov 21, 2023
7644dd4
added test for batched matmul with split dimension being a batch dime…
FOsterfeld Nov 21, 2023
5d34282
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Nov 21, 2023
095ccc5
fixed condition for different batch dimensions
FOsterfeld Nov 21, 2023
e55f8b8
added some tests for correctly thrown errors
FOsterfeld Nov 21, 2023
a9ae2bf
fixed test for batched matmul on gpu
FOsterfeld Nov 21, 2023
06913c7
test for batched matmul on gpu
FOsterfeld Nov 21, 2023
ba60c82
remove unnecessary test with device=gpu
FOsterfeld Nov 22, 2023
8d95ec1
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Nov 23, 2023
1c2a939
batched matmul with split==None for both matrices
FOsterfeld Nov 28, 2023
a79e42d
implemented batched matmul for case split 00
FOsterfeld Dec 14, 2023
980d0ec
implemented batched matmul for case split 01
FOsterfeld Dec 27, 2023
a44c6b2
implemented batched matmul for case split 11
FOsterfeld Dec 27, 2023
b506a66
cleaned up code to return the result
FOsterfeld Dec 27, 2023
f76e973
added tests for the batched matmul
FOsterfeld Dec 27, 2023
9733a28
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jan 2, 2024
0008a3f
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jan 8, 2024
18cdcf1
added batched matmul tests for float values
FOsterfeld Jan 9, 2024
4e49aa5
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jan 22, 2024
bb0856b
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 1, 2024
0d37ff4
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 5, 2024
8531106
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 6, 2024
c911e45
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 14, 2024
5a2ad15
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Feb 20, 2024
e804c2c
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Feb 29, 2024
e5ff10b
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Mar 7, 2024
da9b0e3
improved exception throwing: error message when only one matrix has split None
FOsterfeld Mar 19, 2024
f3e0ced
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jun 3, 2024
2719f4d
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Jun 20, 2024
0f4c677
warn against the inefficient split cases in the matmul docstring
FOsterfeld Jul 4, 2024
96121ee
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Jul 5, 2024
5e5eea3
Update basics.py
mrfh92 Jul 5, 2024
5933e48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
0a9795f
Update basics.py
mrfh92 Jul 5, 2024
a2f1cc5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
816cd85
fixed style complaints
Jul 5, 2024
40aa455
Apply suggestions from code review
FOsterfeld Jul 17, 2024
ea18fa1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2024
2222bb7
fixed documentation
FOsterfeld Jul 17, 2024
35ff132
Merge branch 'features/1104-Implement_consistent_linear_algebra_for_a…
FOsterfeld Jul 17, 2024
fd44be9
updated matmul tests for new batch behavior
FOsterfeld Aug 16, 2024
de73236
restructured code to remove code duplication of batched and unbatched…
FOsterfeld Aug 16, 2024
98a4134
generalized the split case None-None to batched matrices
FOsterfeld Aug 17, 2024
9e7d0f0
simplified the cases where not both matrices are split in la dimensions
FOsterfeld Aug 17, 2024
c95be79
generalized the None splits for batched matrices
FOsterfeld Aug 17, 2024
8dffcf5
removed unnecessary import
FOsterfeld Aug 17, 2024
41e203c
updated docstring
FOsterfeld Aug 17, 2024
3886942
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Aug 17, 2024
9f9462c
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Aug 20, 2024
3f15690
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
FOsterfeld Aug 28, 2024
ca066b2
initialize random generator
FOsterfeld Aug 30, 2024
398b27e
refactored code for None splits
FOsterfeld Aug 30, 2024
6f92537
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mrfh92 Sep 2, 2024
fc97280
Merge branch 'main' into features/1104-Implement_consistent_linear_al…
mtar Sep 6, 2024
improved exception throwing: error message when only one matrix has split None
FOsterfeld committed Mar 19, 2024
commit da9b0e3d14f8aa31411fb468da09b1dbfbc3e659
20 changes: 14 additions & 6 deletions heat/core/linalg/basics.py
@@ -538,21 +538,29 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     tdev = dev.torch_device
     batch_shape = a.gshape[:batch_dim]
 
+    if (
+        a.split is None or b.split is None and a.split != b.split
+    ):  # only one matrix has split None
+        raise NotImplementedError("Only one matrix has split None!")
+
+    if (
+        a.split is not None
+        and (a.split < batch_dim or b.split < batch_dim)
+        and a.split != b.split
+    ):  # not the same batch axis for split
+        raise NotImplementedError(
+            "If one matrix is split along a batch axis, both have to be split along that axis!"
+        )
+
     # la dimension not split -> torch
     if a.split is None and b.split is None or a.split < batch_dim and b.split < batch_dim:
-        if a.split != b.split:
-            raise NotImplementedError("Split axes are different batch axes!")
-
         ret = factories.array(
             torch.matmul(a.larray, b.larray), is_split=a.split, device=a.device, comm=a.comm
         )
         if gpu_int_flag:
             ret = og_type(ret, device=dev)
         return ret
 
-    if a.split is None or b.split is None:  # only one matrix has split None
-        raise NotImplementedError
-
     # block sizes dont need to be the same. they just need the same inner dimension (kB)
     kB = 0  # redundant?
     rem_a, rem_b = 0, 0
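A note on reading the first new condition: Python's "and" binds tighter than "or", so it parses as a.split is None or (b.split is None and a.split != b.split). The following minimal sketch, assuming Heat's public ht.ones/ht.matmul API and purely illustrative shapes and variable names, shows which split combinations these checks accept and reject as of this commit:

# Minimal sketch of the split rules enforced by this commit (hypothetical
# shapes; assumes heat's public API: ht.ones, ht.matmul). Run under MPI for
# the splits to matter, e.g.: mpirun -n 2 python sketch.py
import heat as ht

# Both operands split along the same batch axis: each process multiplies its
# local batch chunk via torch.matmul.
a = ht.ones((4, 5, 6), split=0)  # batch of 4 matrices of shape 5 x 6
b = ht.ones((4, 6, 3), split=0)  # batch of 4 matrices of shape 6 x 3
c = ht.matmul(a, b)  # works; the result keeps the batch split

# One operand split, the other split=None: rejected by the first new check.
d = ht.ones((4, 6, 3))  # split=None by default
try:
    ht.matmul(a, d)
except NotImplementedError as err:
    print(err)  # Only one matrix has split None!

# Two batch axes, operands split along different ones: rejected by the
# second new check.
e = ht.ones((2, 4, 5, 6), split=0)
f = ht.ones((2, 4, 6, 3), split=1)
try:
    ht.matmul(e, f)
except NotImplementedError as err:
    print(err)  # If one matrix is split along a batch axis, both have to be split along that axis!

Per the diff, the mixed-None case previously raised a bare NotImplementedError; this commit replaces it with the descriptive messages exercised above.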