Broadcasting for linalg functions that accept an axis #617

Closed
@asmeurer

Two linear algebra functions operate over an arbitrary, user-specified axis: cross and vecdot.

These APIs currently specify:

x2: Must be compatible with x1 for all non-compute axes (see Broadcasting).

as well as

The compute axis (dimension) must not be broadcasted.

This is ambiguous, however, when the contracted axis is a dimension that is created by broadcasting, for instance:

vecdot(zeros((1, 4, 5)), zeros((4, 5)), axis=0)

Here axis=0 refers to a dimension that has size 1 in the first array and exists in the second array only once broadcasting has prepended it.

I think this case should be disallowed.
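To make the problem concrete, here is a small shape-only sketch of that example. It only uses np.broadcast_shapes; the variable names are mine and are not part of any spec text.

```python
import numpy as np

x1_shape = (1, 4, 5)
x2_shape = (4, 5)
axis = 0

# Shape that the two inputs broadcast to.
out_shape = np.broadcast_shapes(x1_shape, x2_shape)

# x2's shape after left-padding with 1s up to the broadcast rank.
x2_aligned = (1,) * (len(out_shape) - len(x2_shape)) + x2_shape

print(out_shape)         # (1, 4, 5)
print(x1_shape[axis])    # 1: the contracted axis has size 1 in x1
print(x2_aligned[axis])  # 1: in x2 it only exists because of broadcasting
```

So under the "axis refers to the broadcast shape" reading, the contraction would run over a dimension that neither input really provides vectors along.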

Additionally, axis is ambiguous. It isn't clear if it should refer to the axis before or after broadcasting:

axis (int) – the axis (dimension) of x1 and x2 containing the vectors for which to compute the cross product. Must be an integer on the interval [-N, N), where N is the rank (number of dimensions) of the shape determined according to Broadcasting. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where -1 refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: -1.

This is of particular interest if axis >= 0.
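A quick illustration of why nonnegative values are the interesting case. This is only a sketch; size_before_and_after is a hypothetical helper I made up for the comparison, not an API from the spec or NumPy.

```python
import numpy as np

x1_shape = (3, 4, 5)
x2_shape = (4, 5)
out_shape = np.broadcast_shapes(x1_shape, x2_shape)  # (3, 4, 5)


def size_before_and_after(shape, out_shape, axis):
    """Return the size of `axis` in the array's own shape, and in its shape
    once left-padded with 1s to the broadcast rank."""
    aligned = (1,) * (len(out_shape) - len(shape)) + tuple(shape)
    return shape[axis], aligned[axis]


# axis=0 means a different dimension of x2 before and after broadcasting.
print(size_before_and_after(x2_shape, out_shape, 0))   # (4, 1)
# A negative axis that is valid for x2 picks the same dimension either way.
print(size_before_and_after(x2_shape, out_shape, -1))  # (5, 5)
```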

NumPy appears to refer to the axis before broadcasting:

>>> np.cross(np.zeros((3, 1, 2)), np.zeros((3,)), axis=0)
array([[[0., 0.]],

       [[0., 0.]],

       [[0., 0.]]])

In fact, these two arrays aren't strictly broadcast compatible. What NumPy does is move the axis dimension of x1 and x2 to the end of the arrays, then broadcast x1[..., 0] and x2[..., 0]. Effectively:

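from numpy import broadcast, moveaxis

# Move the vector axis of each input to the end.
x1 = moveaxis(x1, axis, -1)
x2 = moveaxis(x2, axis, -1)
if x1.shape[-1] != 3 or x2.shape[-1] != 3:
    raise ValueError("incompatible dimensions for cross product")

# Broadcast only the remaining (non-vector) dimensions.
shape = broadcast(x1[..., 0], x2[..., 0]).shape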

In other words, the arrays should be broadcast compatible after removing axis from the shape (and we should have x1.shape[axis] == x2.shape[axis] == 3).
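As a rough sketch of that rule, here is a shape-only helper. cross_result_shape is a hypothetical name of mine, and putting the size-3 axis back at position axis in the result is my reading of what np.cross does when a single axis argument is given.

```python
import numpy as np


def cross_result_shape(x1_shape, x2_shape, axis=-1):
    """Result shape under 'remove axis from each input, broadcast the rest'."""

    def split(shape):
        ndim = len(shape)
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for shape {shape}")
        i = axis % ndim
        # Remaining dimensions, and the size of the vector axis.
        return tuple(shape[:i]) + tuple(shape[i + 1:]), shape[i]

    rest1, n1 = split(x1_shape)
    rest2, n2 = split(x2_shape)
    if n1 != 3 or n2 != 3:
        raise ValueError("incompatible dimensions for cross product")

    batch = list(np.broadcast_shapes(rest1, rest2))
    # Reinsert the vector axis, counting relative to the result rank.
    pos = axis if axis >= 0 else axis + len(batch) + 1
    batch.insert(pos, 3)
    return tuple(batch)


# The NumPy example above: the full shapes do not broadcast, the rule still applies.
print(cross_result_shape((3, 1, 2), (3,), axis=0))  # (3, 1, 2)
```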

NumPy doesn't have vecdot yet, but it should obviously work the same (the only difference being the contracted axis can have any size in vecdot, not just 3, and unlike cross, in vecdot the contracted axis is removed from the resulting shape). torch.linalg.cross doesn't appear to support any broadcasting.
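For illustration, a minimal sketch of what vecdot could look like with those np.cross-style semantics. vecdot_like_cross is a hypothetical name, and the sketch only targets real-valued inputs (it does not attempt any complex conjugation).

```python
import numpy as np


def vecdot_like_cross(x1, x2, axis=-1):
    """vecdot where `axis` is resolved on each input separately, before broadcasting."""
    x1 = np.moveaxis(np.asarray(x1), axis, -1)
    x2 = np.moveaxis(np.asarray(x2), axis, -1)
    # The contracted axis must match exactly; it is never broadcast.
    if x1.shape[-1] != x2.shape[-1]:
        raise ValueError("contracted axis sizes do not match")
    # The remaining dimensions broadcast as usual; the contracted axis is summed away.
    return (x1 * x2).sum(axis=-1)


out = vecdot_like_cross(np.ones((3, 1, 2)), np.ones((3,)), axis=0)
print(out.shape)  # (1, 2)

# The ambiguous example from the top of this issue is rejected (sizes 1 vs 4).
try:
    vecdot_like_cross(np.zeros((1, 4, 5)), np.zeros((4, 5)), axis=0)
except ValueError as exc:
    print("rejected:", exc)
```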

My implementations of vecdot in numpy.array_api and array-api-compat have been using the idea that axis refers to the axis after broadcasting, and they allow a contracted axis that is added by broadcasting. But I think this should be changed to work like np.cross. The numpy.array_api and array_api_compat.numpy cross implementations just reuse np.cross and therefore use those semantics (I didn't realize until now that we weren't actually testing any broadcasting rules for cross in the test suite).

This was discussed at data-apis/array-api-compat#35 (comment) (CC @lezcano).

Finally, note that tensordot doesn't have this issue because the axes are specified for each array separately.
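A small example of the tensordot case, using np.tensordot with made-up shapes, just to show that each array gets its own axis so nothing depends on a broadcast shape:

```python
import numpy as np

a = np.ones((3, 4))
b = np.ones((5, 3))

# Contract axis 0 of `a` with axis 1 of `b`; each axis is named per array.
result = np.tensordot(a, b, axes=([0], [1]))
print(result.shape)  # (4, 5)
```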
