Closed
Description
The following lines raise a warning when `ot.lp.wasserstein_1d` is called with 1-D inputs on the torch backend.
Lines 41 to 42 in 3a53dff
```python
cws = cws.T.contiguous()
qs = qs.T.contiguous()
```
```
/mnt/home/frozet/.venvs/lola/lib/python3.11/site-packages/ot/lp/solver_1d.py:41:
UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse
their shape is deprecated and it will throw an error in a future release. Consider
`x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))`
to reverse the dimensions of a tensor. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3683.)
  cws = cws.T.contiguous()
```
This can be fixed by using `cws = cws.movedim(0, -1).contiguous()` instead (and likewise for `qs`). Unlike `.T`, `movedim(0, -1)` is valid for tensors with any number of dimensions, including more than 2.
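A minimal sketch of the suggested replacement, assuming PyTorch is installed. `move_first_axis_last` is a hypothetical helper name used only for illustration (it is not part of POT); it just wraps `torch.Tensor.movedim`. The checks show that it matches the old `.T` on 2-D inputs, leaves 1-D inputs unchanged (the case that triggers the deprecation warning with `.T`), and also accepts higher-dimensional inputs.

```python
import torch


def move_first_axis_last(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: move axis 0 to the end and return a contiguous tensor.
    # Unlike x.T, movedim does not emit the deprecation warning for ndim != 2.
    return x.movedim(0, -1).contiguous()


# 2-D case: same result as the old cws.T.contiguous()
cws_2d = torch.arange(6.0).reshape(3, 2)
out_2d = move_first_axis_last(cws_2d)
print(out_2d.shape)  # torch.Size([2, 3])

# 1-D case: the shape is unchanged, as with .T, but no warning is raised
cws_1d = torch.arange(4.0)
out_1d = move_first_axis_last(cws_1d)
print(out_1d.shape)  # torch.Size([4])

# 3-D case: only the first axis moves, (4, 2, 3) -> (2, 3, 4)
cws_3d = torch.zeros(4, 2, 3)
print(move_first_axis_last(cws_3d).shape)  # torch.Size([2, 3, 4])
```

For 1-D and 2-D inputs the two forms agree. For three or more dimensions they differ: `.T` reverses every axis, while `movedim(0, -1)` only moves the first axis to the end.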