Skip to content

Commit 379b5b7

Browse files
Illviljankeewismax-sixty
authored
Add support for cross product (#5365)
* Add support for cross * Update test_computation.py * Update computation.py * Update computation.py * Update test_computation.py * Update test_computation.py * Update test_computation.py * add more tests * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * spatial_dim to dim * Update computation.py * use pad instead of concat * copy paste np.cross intro * Get last dim for each array, which is more in line with np.cross * examples in docs * Update computation.py * more doc examples * single dim required, transpose after apply_ufunc * add dims to tests * Update computation.py * reduce code * support xr.Variable * Update computation.py * Update computation.py * reduce code * docstring explanations * Use same terms * docstring formatting * reduce code * add tests for dask * simplify check, align used variables * trim down tests * Update computation.py * simplify code * Add type hints * less type hints * Update computation.py * undo type hints * Update computation.py * Add support for datasets * determine dtype with np.result_type * test datasets, daskify the inputs not the results * rechunk padded values, handle 1 sized datasets * expand only unique dims, squeeze out dims in tests * rechunk along the dim * Attempt typing again * Update __init__.py * Update computation.py * Update computation.py * test fixing type in to_stacked_array * test fixing to_stacked_array * small is large * Update computation.py * Update xarray/core/computation.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * obfuscate variable_dim some * Update computation.py * undo to_stacked_array changes * test sample_dims typing * to_stacked_array fixes * add reindex_like check * Update computation.py * Update computation.py * Update computation.py * test forcing int type in chunk() * Update computation.py * test collection in to_stacked_array * Update computation.py * Update computation.py * Update computation.py * Update 
computation.py * Update computation.py * whats new and api.rst * Update whats-new.rst * Output as dataset if any input is a dataset * Simplify the if terms instead of using pass. * Update computation.py * Remove support for datasets * Update computation.py * Add some typing to test. * doctest fix * lint * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * Update computation.py * Update computation.py * Update computation.py * Update computation.py * Update computation.py * Can't narrow types with old type Seems using bounds in typevar makes it impossible to narrow the type using isinstance checks. * dim now keyword only * use all_dims in transpose * if in transpose indeed needed if a and b has size 2 it's needed. * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * Update xarray/core/computation.py Co-authored-by: keewis <keewis@users.noreply.github.com> * Update computation.py * Update computation.py * add todo comments * Update whats-new.rst Co-authored-by: keewis <keewis@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
1 parent 92ac89f commit 379b5b7

File tree

5 files changed

+330
-1
lines changed

5 files changed

+330
-1
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Top-level functions
3232
ones_like
3333
cov
3434
corr
35+
cross
3536
dot
3637
polyval
3738
map_blocks

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ v0.21.0 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24+
- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
25+
By `Jimmy Westling <https://github.com/illviljan>`_.
2426

2527

2628
Breaking changes

xarray/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
from .core.alignment import align, broadcast
1717
from .core.combine import combine_by_coords, combine_nested
1818
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
19-
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
19+
from .core.computation import (
20+
apply_ufunc,
21+
corr,
22+
cov,
23+
cross,
24+
dot,
25+
polyval,
26+
unify_chunks,
27+
where,
28+
)
2029
from .core.concat import concat
2130
from .core.dataarray import DataArray
2231
from .core.dataset import Dataset
@@ -60,6 +69,7 @@
6069
"dot",
6170
"cov",
6271
"corr",
72+
"cross",
6373
"full_like",
6474
"get_options",
6575
"infer_freq",

xarray/core/computation.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
if TYPE_CHECKING:
3838
from .coordinates import Coordinates
39+
from .dataarray import DataArray
3940
from .dataset import Dataset
4041
from .types import T_Xarray
4142

@@ -1373,6 +1374,214 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
13731374
return corr
13741375

13751376

1377+
def cross(
    a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable
) -> Union[DataArray, Variable]:
    """
    Compute the cross product of two (arrays of) vectors.

    The cross product of `a` and `b` in :math:`R^3` is a vector
    perpendicular to both `a` and `b`. The vectors in `a` and `b` are
    defined by the values along the dimension `dim` and can have sizes
    1, 2 or 3. Where the size of either `a` or `b` is
    1 or 2, the remaining components of the input vector are assumed to
    be zero and the cross product calculated accordingly. In cases where
    both input vectors have dimension 2, the z-component of the cross
    product is returned.

    Parameters
    ----------
    a, b : DataArray or Variable
        Components of the first and second vector(s).
    dim : hashable
        The dimension along which the cross product will be computed.
        Must be available in both vectors.

    Returns
    -------
    DataArray or Variable
        The cross product of `a` and `b`. `dim` is kept in the result
        when the vectors have size 3 along it (after any zero filling)
        and is dropped otherwise.

    Raises
    ------
    ValueError
        If `dim` is missing from `a` or `b`, if the size of `dim` on
        either input is not 1, 2 or 3, or if the sizes differ, the
        smaller one is 1 and `dim` has no coordinates on both inputs to
        tell which components are missing.

    Examples
    --------
    Vector cross-product with 3 dimensions:

    >>> a = xr.DataArray([1, 2, 3])
    >>> b = xr.DataArray([4, 5, 6])
    >>> xr.cross(a, b, dim="dim_0")
    <xarray.DataArray (dim_0: 3)>
    array([-3,  6, -3])
    Dimensions without coordinates: dim_0

    Vector cross-product with 2 dimensions, returns in the perpendicular
    direction:

    >>> a = xr.DataArray([1, 2])
    >>> b = xr.DataArray([4, 5])
    >>> xr.cross(a, b, dim="dim_0")
    <xarray.DataArray ()>
    array(-3)

    Vector cross-product with 3 dimensions but zeros at the last axis
    yields the same results as with 2 dimensions:

    >>> a = xr.DataArray([1, 2, 0])
    >>> b = xr.DataArray([4, 5, 0])
    >>> xr.cross(a, b, dim="dim_0")
    <xarray.DataArray (dim_0: 3)>
    array([ 0,  0, -3])
    Dimensions without coordinates: dim_0

    One vector with dimension 2:

    >>> a = xr.DataArray(
    ...     [1, 2],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "y"])),
    ... )
    >>> b = xr.DataArray(
    ...     [4, 5, 6],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
    ... )
    >>> xr.cross(a, b, dim="cartesian")
    <xarray.DataArray (cartesian: 3)>
    array([12, -6, -3])
    Coordinates:
      * cartesian  (cartesian) <U1 'x' 'y' 'z'

    One vector with dimension 2 but coords in other positions:

    >>> a = xr.DataArray(
    ...     [1, 2],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "z"])),
    ... )
    >>> b = xr.DataArray(
    ...     [4, 5, 6],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
    ... )
    >>> xr.cross(a, b, dim="cartesian")
    <xarray.DataArray (cartesian: 3)>
    array([-10,   2,   5])
    Coordinates:
      * cartesian  (cartesian) <U1 'x' 'y' 'z'

    Multiple vector cross-products. Note that the direction of the
    cross product vector is defined by the right-hand rule:

    >>> a = xr.DataArray(
    ...     [[1, 2, 3], [4, 5, 6]],
    ...     dims=("time", "cartesian"),
    ...     coords=dict(
    ...         time=(["time"], [0, 1]),
    ...         cartesian=(["cartesian"], ["x", "y", "z"]),
    ...     ),
    ... )
    >>> b = xr.DataArray(
    ...     [[4, 5, 6], [1, 2, 3]],
    ...     dims=("time", "cartesian"),
    ...     coords=dict(
    ...         time=(["time"], [0, 1]),
    ...         cartesian=(["cartesian"], ["x", "y", "z"]),
    ...     ),
    ... )
    >>> xr.cross(a, b, dim="cartesian")
    <xarray.DataArray (time: 2, cartesian: 3)>
    array([[-3,  6, -3],
           [ 3, -6,  3]])
    Coordinates:
      * time       (time) int64 0 1
      * cartesian  (cartesian) <U1 'x' 'y' 'z'

    Cross can be called on Datasets by converting to DataArrays and later
    back to a Dataset:

    >>> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3])))
    >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6])))
    >>> c = xr.cross(
    ...     ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian"
    ... )
    >>> c.to_dataset(dim="cartesian")
    <xarray.Dataset>
    Dimensions:  (dim_0: 1)
    Dimensions without coordinates: dim_0
    Data variables:
        x        (dim_0) int64 -3
        y        (dim_0) int64 6
        z        (dim_0) int64 -3

    See Also
    --------
    numpy.cross : Corresponding numpy function
    """

    # Validate in two passes so that a missing dimension on either input
    # is reported before any size problem, for both `a` and `b`.
    for name, arr in (("a", a), ("b", b)):
        if dim not in arr.dims:
            raise ValueError(f"Dimension {dim!r} not on {name}")

    for name, arr in (("a", a), ("b", b)):
        if not 1 <= arr.sizes[dim] <= 3:
            raise ValueError(
                f"The size of {dim!r} on {name} must be 1, 2, or 3 to be "
                f"compatible with a cross product but is {arr.sizes[dim]}"
            )

    # Ordered union of the input dims, used to restore the dimension
    # order once apply_ufunc has moved the core dimension around.
    all_dims = list(dict.fromkeys(a.dims + b.dims))

    if a.sizes[dim] != b.sizes[dim]:
        # Arrays have different sizes. Append zeros where the smaller
        # array is missing a value, zeros will not affect np.cross:

        if (
            not isinstance(a, Variable)  # Only used to make mypy happy.
            and dim in getattr(a, "coords", {})
            and not isinstance(b, Variable)  # Only used to make mypy happy.
            and dim in getattr(b, "coords", {})
        ):
            # If the arrays have coords we know which indexes to fill
            # with zeros:
            a, b = align(
                a,
                b,
                fill_value=0,
                join="outer",
                exclude=set(all_dims) - {dim},
            )
        elif min(a.sizes[dim], b.sizes[dim]) == 2:
            # If the array doesn't have coords we can only infer
            # that it has composite values if the size is at least 2.
            # Once padded, rechunk the padded array because apply_ufunc
            # requires core dimensions not to be chunked:
            if a.sizes[dim] < b.sizes[dim]:
                a = a.pad({dim: (0, 1)}, constant_values=0)
                # TODO: Should pad or apply_ufunc handle correct chunking?
                a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a
            else:
                b = b.pad({dim: (0, 1)}, constant_values=0)
                # TODO: Should pad or apply_ufunc handle correct chunking?
                b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b
        else:
            # The sizes differ and the smaller one is 1: without
            # coordinates there is no way to tell which component it is.
            raise ValueError(
                f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:"
                " dimensions without coordinates must have a length of 2 or 3"
            )

    c = apply_ufunc(
        np.cross,
        a,
        b,
        input_core_dims=[[dim], [dim]],
        # np.cross only returns a vector when the inputs have 3
        # components; two 2-component inputs yield the scalar z-component.
        output_core_dims=[[dim] if a.sizes[dim] == 3 else []],
        dask="parallelized",
        output_dtypes=[np.result_type(a, b)],
    )
    # apply_ufunc puts core dims last; restore the original order. `dim`
    # is absent from the result for 2-component inputs, hence "ignore".
    c = c.transpose(*all_dims, missing_dims="ignore")

    return c
1583+
1584+
13761585
def dot(*arrays, dims=None, **kwargs):
13771586
"""Generalized dot product for xarray objects. Like np.einsum, but
13781587
provides a simpler interface based on array dimensions.

xarray/tests/test_computation.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,3 +1952,110 @@ def test_polyval(use_dask, use_datetime) -> None:
19521952
da_pv = xr.polyval(da.x, coeffs)
19531953

19541954
xr.testing.assert_allclose(da, da_pv.T)
1955+
1956+
1957+
@pytest.mark.parametrize("use_dask", [False, True])
@pytest.mark.parametrize(
    "a, b, ae, be, dim, axis",
    [
        # Equal-sized DataArrays without coordinates:
        (
            xr.DataArray([1, 2, 3]),
            xr.DataArray([4, 5, 6]),
            [1, 2, 3],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        # DataArrays of size 2 and 3 without coordinates:
        (
            xr.DataArray([1, 2]),
            xr.DataArray([4, 5, 6]),
            [1, 2],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        # Equal-sized Variables:
        (
            xr.Variable(dims=["dim_0"], data=[1, 2, 3]),
            xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
            [1, 2, 3],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        # Variables of size 2 and 3:
        (
            xr.Variable(dims=["dim_0"], data=[1, 2]),
            xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
            [1, 2],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        # Test dim in the middle:
        (
            xr.DataArray(
                np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
                dims=["time", "cartesian", "var"],
                coords={
                    "time": (["time"], np.arange(0, 5)),
                    "cartesian": (["cartesian"], ["x", "y", "z"]),
                    "var": (["var"], [1, 1.5, 2, 2.5]),
                },
            ),
            xr.DataArray(
                np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
                dims=["time", "cartesian", "var"],
                coords={
                    "time": (["time"], np.arange(0, 5)),
                    "cartesian": (["cartesian"], ["x", "y", "z"]),
                    "var": (["var"], [1, 1.5, 2, 2.5]),
                },
            ),
            np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
            np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
            "cartesian",
            1,
        ),
        # Test 1 sized arrays with coords:
        (
            xr.DataArray(
                np.array([1]),
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["z"])},
            ),
            xr.DataArray(
                np.array([4, 5, 6]),
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["x", "y", "z"])},
            ),
            [0, 0, 1],
            [4, 5, 6],
            "cartesian",
            -1,
        ),
        # Test filling in between with coords:
        (
            xr.DataArray(
                [1, 2],
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["x", "z"])},
            ),
            xr.DataArray(
                [4, 5, 6],
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["x", "y", "z"])},
            ),
            [1, 0, 2],
            [4, 5, 6],
            "cartesian",
            -1,
        ),
    ],
)
def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None:
    # `ae` and `be` are plain array-likes, written out with any zero
    # filling already applied, from which np.cross gives the reference.
    expected = np.cross(ae, be, axis=axis)

    if use_dask and not has_dask:
        pytest.skip("test for dask.")

    # Chunk the inputs (not the result) when exercising the dask path.
    lhs, rhs = (a.chunk(), b.chunk()) if use_dask else (a, b)

    actual = xr.cross(lhs, rhs, dim=dim)
    xr.testing.assert_duckarray_allclose(expected, actual)

0 commit comments

Comments
 (0)