Description
Two linear algebra functions, cross and vecdot, allow contraction over an arbitrary axis.
These APIs currently specify:
x2: Must be compatible with x1 for all non-compute axes (see Broadcasting).
as well as
The compute axis (dimension) must not be broadcasted.
This is ambiguous, however, when the contracted axis is a dimension that is created by broadcasting, for instance:
vecdot(zeros((1, 4, 5)), zeros((4, 5)), axis=0)
Here, axis=0 would refer to the first dimension, which has size 1 in x1 and does not exist in x2 before broadcasting.
I think this case should be disallowed.
Additionally, axis is ambiguous. It isn't clear if it should refer to the axis before or after broadcasting:
axis (int) – the axis (dimension) of x1 and x2 containing the vectors for which to compute the cross product. Must be an integer on the interval [-N, N), where N is the rank (number of dimensions) of the shape determined according to Broadcasting. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where -1 refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: -1.
This is of particular interest if axis >= 0.
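To make the difference concrete, here is a small sketch with hypothetical shapes (not taken from the spec) showing that a non-negative axis can pick out different dimensions of x2 depending on whether it is counted before or after broadcasting, while a negative axis is unaffected:

```python
import numpy as np

# Hypothetical shapes chosen for illustration; they are broadcast compatible.
x1_shape = (2, 3, 4)
x2_shape = (3, 4)
broadcast_shape = np.broadcast_shapes(x1_shape, x2_shape)

# Broadcasting conceptually left-pads the lower-rank shape with 1s.
x2_padded = (1,) * (len(broadcast_shape) - len(x2_shape)) + x2_shape

# axis=1 counted on the original shapes: different sizes in x1 and x2.
before = (x1_shape[1], x2_shape[1])

# axis=1 counted on the broadcast-aligned shapes: the same dimension in both.
after = (x1_shape[1], x2_padded[1])

# A negative axis counts from the end, so padding on the left does not matter.
negative = (x1_shape[-1], x2_shape[-1])

print(before, after, negative)
```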
NumPy appears to refer to the axis before broadcasting:
>>> np.cross(np.zeros((3, 1, 2)), np.zeros((3,)), axis=0)
array([[[0., 0.]],
[[0., 0.]],
[[0., 0.]]])
In fact, these two arrays aren't strictly broadcast compatible. What NumPy does is move the axis dimension of x1 and x2 to the end of the arrays, then broadcast x1[..., 0] and x2[..., 0]. Effectively:
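x1 = np.moveaxis(x1, axis, -1)
x2 = np.moveaxis(x2, axis, -1)
# After the move, the vector axis is the last one in both arrays
if x1.shape[-1] != 3 or x2.shape[-1] != 3:
    raise ValueError("incompatible dimensions for cross product")
# Only the remaining (non-vector) dimensions are broadcast against each other
shape = np.broadcast(x1[..., 0], x2[..., 0]).shape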
In other words, the arrays should be broadcast compatible after removing axis from the shape (and we should have x1.shape[axis] == x2.shape[axis] == 3).
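As a rough sketch of that rule, the result shape could be computed as follows. The helper name cross_result_shape is hypothetical and only models the shape logic described above (including placing the size-3 axis back at position axis in the result, which is my reading of what np.cross does when a single axis is given), not the actual computation:

```python
import numpy as np

def cross_result_shape(x1_shape, x2_shape, axis=-1):
    """Hypothetical helper: result shape under the np.cross-like rule."""

    def split(shape):
        # Separate the size of the vector axis from the remaining dimensions.
        ndim = len(shape)
        if not -ndim <= axis < ndim:
            raise ValueError("axis is out of bounds for an input")
        ax = axis % ndim
        return shape[ax], shape[:ax] + shape[ax + 1:]

    n1, rest1 = split(tuple(x1_shape))
    n2, rest2 = split(tuple(x2_shape))
    if n1 != 3 or n2 != 3:
        raise ValueError("incompatible dimensions for cross product")

    # Only the non-vector dimensions need to be broadcast compatible.
    batch = np.broadcast_shapes(rest1, rest2)

    # Put the size-3 axis back at position `axis` of the result.
    result_ndim = len(batch) + 1
    if not -result_ndim <= axis < result_ndim:
        raise ValueError("axis is out of bounds for the result")
    pos = axis % result_ndim
    return batch[:pos] + (3,) + batch[pos:]

print(cross_result_shape((3, 1, 2), (3,), axis=0))
```

For the shapes from the np.cross example above, this gives (3, 1, 2), matching the shape of the array NumPy returned.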
NumPy doesn't have vecdot yet, but it should obviously work the same (the only difference being the contracted axis can have any size in vecdot, not just 3, and unlike cross, in vecdot the contracted axis is removed from the resulting shape). torch.linalg.cross doesn't appear to support any broadcasting.
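A minimal sketch of what vecdot could look like under these semantics, written against plain NumPy. The name vecdot_sketch is hypothetical and this is not the numpy.array_api or array-api-compat implementation; conjugating x1 is my assumption, based on how the spec defines vecdot for complex inputs:

```python
import numpy as np

def vecdot_sketch(x1, x2, axis=-1):
    """Hypothetical vecdot where `axis` refers to the inputs before broadcasting."""
    # Move the contracted axis of each input to the end.
    x1 = np.moveaxis(np.asarray(x1), axis, -1)
    x2 = np.moveaxis(np.asarray(x2), axis, -1)

    # The contracted axis must match exactly; it is never broadcast.
    if x1.shape[-1] != x2.shape[-1]:
        raise ValueError("contracted axis must have the same size in x1 and x2")

    # The remaining dimensions broadcast as usual; the last axis is summed away.
    return np.sum(np.conj(x1) * x2, axis=-1)

# Same shapes as the np.cross example: the contracted axis is dropped.
out = vecdot_sketch(np.ones((3, 1, 2)), np.ones((3,)), axis=0)
print(out.shape)

# The ambiguous case from the top of this issue is rejected.
try:
    vecdot_sketch(np.zeros((1, 4, 5)), np.zeros((4, 5)), axis=0)
except ValueError as exc:
    print("rejected:", exc)
```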
My implementations of vecdot in numpy.array_api and array-api-compat have been using the idea that axis refers to the axis after broadcasting, and they allow the contracted axis to be one that is added by broadcasting. But I think this should be changed to work like np.cross. The numpy.array_api and array_api_compat.numpy cross implementations just reuse np.cross and therefore use those semantics (I didn't realize until now that we weren't actually testing any broadcasting rules for cross in the test suite).
This was discussed at data-apis/array-api-compat#35 (comment) (CC @lezcano).
Finally, note that tensordot doesn't have this issue because the axes are specified for each array separately.
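For comparison, a short NumPy illustration (shapes chosen arbitrarily) of how tensordot names the contracted axis of each input separately, so there is no question of which array's numbering is meant:

```python
import numpy as np

a = np.ones((3, 4))
b = np.ones((5, 3))

# Contract axis 0 of `a` with axis 1 of `b`; each axis is given per array.
result = np.tensordot(a, b, axes=([0], [1]))
print(result.shape)
```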