What happened?
During testing, I found that PyTorch's matmul does not support batched matrix-vector multiplication: if an argument has more than two dimensions, the shapes (including the batch dimensions) are simply broadcast and the operation is treated as a batched matrix-matrix multiplication. As a result, batched matrix-vector multiplication (#1261) does not work in all cases.
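A minimal sketch of the behaviour described above (shapes chosen arbitrarily for illustration): when the vector argument carries its own batch dimension, torch.matmul treats it as an ordinary matrix rather than as a batch of vectors.

```python
import torch

A = torch.randn(4, 3, 5)  # batch of 4 matrices, each 3 x 5
x = torch.randn(4, 5)     # batch of 4 vectors, each of length 5

# The 2-D second argument is interpreted as a single 4 x 5 matrix and
# broadcast against the batch, i.e. (3 x 5) @ (4 x 5), so this fails with
# a shape mismatch instead of performing a batched matrix-vector product.
try:
    torch.matmul(A, x)
except RuntimeError as e:
    print(e)
```

Note that if the vector length happened to equal the batch size, the call would even succeed silently and return a batched matrix-matrix product rather than the intended batched matrix-vector product.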
Since this is not supported by PyTorch itself, and since one can just use unsqueeze/expand_dims to handle this, I would suggest removing this functionality from Heat, and maybe raising a NotImplementedError instead.

Relevant link: Matrix-vector multiply (handling batched data) - PyTorch Forums
Code snippet triggering the error
No response
Error message or erroneous outcome
No response
Version
main (development branch)
Python version
None
PyTorch version
None
MPI version
No response