|
36 | 36 |
|
37 | 37 | if TYPE_CHECKING:
|
38 | 38 | from .coordinates import Coordinates
|
| 39 | + from .dataarray import DataArray |
39 | 40 | from .dataset import Dataset
|
40 | 41 | from .types import T_Xarray
|
41 | 42 |
|
@@ -1373,6 +1374,214 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
|
1373 | 1374 | return corr
|
1374 | 1375 |
|
1375 | 1376 |
|
| 1377 | +def cross( |
| 1378 | + a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable |
| 1379 | +) -> Union[DataArray, Variable]: |
| 1380 | + """ |
| 1381 | + Compute the cross product of two (arrays of) vectors. |
| 1382 | +
|
| 1383 | + The cross product of `a` and `b` in :math:`R^3` is a vector |
| 1384 | + perpendicular to both `a` and `b`. The vectors in `a` and `b` are |
| 1385 | + defined by the values along the dimension `dim` and can have sizes |
| 1386 | + 1, 2 or 3. Where the size of either `a` or `b` is |
| 1387 | + 1 or 2, the remaining components of the input vector is assumed to |
| 1388 | + be zero and the cross product calculated accordingly. In cases where |
| 1389 | + both input vectors have dimension 2, the z-component of the cross |
| 1390 | + product is returned. |
| 1391 | +
|
| 1392 | + Parameters |
| 1393 | + ---------- |
| 1394 | + a, b : DataArray or Variable |
| 1395 | + Components of the first and second vector(s). |
| 1396 | + dim : hashable |
| 1397 | + The dimension along which the cross product will be computed. |
| 1398 | + Must be available in both vectors. |
| 1399 | +
|
| 1400 | + Examples |
| 1401 | + -------- |
| 1402 | + Vector cross-product with 3 dimensions: |
| 1403 | +
|
| 1404 | + >>> a = xr.DataArray([1, 2, 3]) |
| 1405 | + >>> b = xr.DataArray([4, 5, 6]) |
| 1406 | + >>> xr.cross(a, b, dim="dim_0") |
| 1407 | + <xarray.DataArray (dim_0: 3)> |
| 1408 | + array([-3, 6, -3]) |
| 1409 | + Dimensions without coordinates: dim_0 |
| 1410 | +
|
| 1411 | + Vector cross-product with 2 dimensions, returns in the perpendicular |
| 1412 | + direction: |
| 1413 | +
|
| 1414 | + >>> a = xr.DataArray([1, 2]) |
| 1415 | + >>> b = xr.DataArray([4, 5]) |
| 1416 | + >>> xr.cross(a, b, dim="dim_0") |
| 1417 | + <xarray.DataArray ()> |
| 1418 | + array(-3) |
| 1419 | +
|
| 1420 | + Vector cross-product with 3 dimensions but zeros at the last axis |
| 1421 | + yields the same results as with 2 dimensions: |
| 1422 | +
|
| 1423 | + >>> a = xr.DataArray([1, 2, 0]) |
| 1424 | + >>> b = xr.DataArray([4, 5, 0]) |
| 1425 | + >>> xr.cross(a, b, dim="dim_0") |
| 1426 | + <xarray.DataArray (dim_0: 3)> |
| 1427 | + array([ 0, 0, -3]) |
| 1428 | + Dimensions without coordinates: dim_0 |
| 1429 | +
|
| 1430 | + One vector with dimension 2: |
| 1431 | +
|
| 1432 | + >>> a = xr.DataArray( |
| 1433 | + ... [1, 2], |
| 1434 | + ... dims=["cartesian"], |
| 1435 | + ... coords=dict(cartesian=(["cartesian"], ["x", "y"])), |
| 1436 | + ... ) |
| 1437 | + >>> b = xr.DataArray( |
| 1438 | + ... [4, 5, 6], |
| 1439 | + ... dims=["cartesian"], |
| 1440 | + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), |
| 1441 | + ... ) |
| 1442 | + >>> xr.cross(a, b, dim="cartesian") |
| 1443 | + <xarray.DataArray (cartesian: 3)> |
| 1444 | + array([12, -6, -3]) |
| 1445 | + Coordinates: |
| 1446 | + * cartesian (cartesian) <U1 'x' 'y' 'z' |
| 1447 | +
|
| 1448 | + One vector with dimension 2 but coords in other positions: |
| 1449 | +
|
| 1450 | + >>> a = xr.DataArray( |
| 1451 | + ... [1, 2], |
| 1452 | + ... dims=["cartesian"], |
| 1453 | + ... coords=dict(cartesian=(["cartesian"], ["x", "z"])), |
| 1454 | + ... ) |
| 1455 | + >>> b = xr.DataArray( |
| 1456 | + ... [4, 5, 6], |
| 1457 | + ... dims=["cartesian"], |
| 1458 | + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), |
| 1459 | + ... ) |
| 1460 | + >>> xr.cross(a, b, dim="cartesian") |
| 1461 | + <xarray.DataArray (cartesian: 3)> |
| 1462 | + array([-10, 2, 5]) |
| 1463 | + Coordinates: |
| 1464 | + * cartesian (cartesian) <U1 'x' 'y' 'z' |
| 1465 | +
|
| 1466 | + Multiple vector cross-products. Note that the direction of the |
| 1467 | + cross product vector is defined by the right-hand rule: |
| 1468 | +
|
| 1469 | + >>> a = xr.DataArray( |
| 1470 | + ... [[1, 2, 3], [4, 5, 6]], |
| 1471 | + ... dims=("time", "cartesian"), |
| 1472 | + ... coords=dict( |
| 1473 | + ... time=(["time"], [0, 1]), |
| 1474 | + ... cartesian=(["cartesian"], ["x", "y", "z"]), |
| 1475 | + ... ), |
| 1476 | + ... ) |
| 1477 | + >>> b = xr.DataArray( |
| 1478 | + ... [[4, 5, 6], [1, 2, 3]], |
| 1479 | + ... dims=("time", "cartesian"), |
| 1480 | + ... coords=dict( |
| 1481 | + ... time=(["time"], [0, 1]), |
| 1482 | + ... cartesian=(["cartesian"], ["x", "y", "z"]), |
| 1483 | + ... ), |
| 1484 | + ... ) |
| 1485 | + >>> xr.cross(a, b, dim="cartesian") |
| 1486 | + <xarray.DataArray (time: 2, cartesian: 3)> |
| 1487 | + array([[-3, 6, -3], |
| 1488 | + [ 3, -6, 3]]) |
| 1489 | + Coordinates: |
| 1490 | + * time (time) int64 0 1 |
| 1491 | + * cartesian (cartesian) <U1 'x' 'y' 'z' |
| 1492 | +
|
| 1493 | + Cross can be called on Datasets by converting to DataArrays and later |
| 1494 | + back to a Dataset: |
| 1495 | +
|
| 1496 | + >>> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) |
| 1497 | + >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) |
| 1498 | + >>> c = xr.cross( |
| 1499 | + ... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian" |
| 1500 | + ... ) |
| 1501 | + >>> c.to_dataset(dim="cartesian") |
| 1502 | + <xarray.Dataset> |
| 1503 | + Dimensions: (dim_0: 1) |
| 1504 | + Dimensions without coordinates: dim_0 |
| 1505 | + Data variables: |
| 1506 | + x (dim_0) int64 -3 |
| 1507 | + y (dim_0) int64 6 |
| 1508 | + z (dim_0) int64 -3 |
| 1509 | +
|
| 1510 | + See Also |
| 1511 | + -------- |
| 1512 | + numpy.cross : Corresponding numpy function |
| 1513 | + """ |
| 1514 | + |
| 1515 | + if dim not in a.dims: |
| 1516 | + raise ValueError(f"Dimension {dim!r} not on a") |
| 1517 | + elif dim not in b.dims: |
| 1518 | + raise ValueError(f"Dimension {dim!r} not on b") |
| 1519 | + |
| 1520 | + if not 1 <= a.sizes[dim] <= 3: |
| 1521 | + raise ValueError( |
| 1522 | + f"The size of {dim!r} on a must be 1, 2, or 3 to be " |
| 1523 | + f"compatible with a cross product but is {a.sizes[dim]}" |
| 1524 | + ) |
| 1525 | + elif not 1 <= b.sizes[dim] <= 3: |
| 1526 | + raise ValueError( |
| 1527 | + f"The size of {dim!r} on b must be 1, 2, or 3 to be " |
| 1528 | + f"compatible with a cross product but is {b.sizes[dim]}" |
| 1529 | + ) |
| 1530 | + |
| 1531 | + all_dims = list(dict.fromkeys(a.dims + b.dims)) |
| 1532 | + |
| 1533 | + if a.sizes[dim] != b.sizes[dim]: |
| 1534 | + # Arrays have different sizes. Append zeros where the smaller |
| 1535 | + # array is missing a value, zeros will not affect np.cross: |
| 1536 | + |
| 1537 | + if ( |
| 1538 | + not isinstance(a, Variable) # Only used to make mypy happy. |
| 1539 | + and dim in getattr(a, "coords", {}) |
| 1540 | + and not isinstance(b, Variable) # Only used to make mypy happy. |
| 1541 | + and dim in getattr(b, "coords", {}) |
| 1542 | + ): |
| 1543 | + # If the arrays have coords we know which indexes to fill |
| 1544 | + # with zeros: |
| 1545 | + a, b = align( |
| 1546 | + a, |
| 1547 | + b, |
| 1548 | + fill_value=0, |
| 1549 | + join="outer", |
| 1550 | + exclude=set(all_dims) - {dim}, |
| 1551 | + ) |
| 1552 | + elif min(a.sizes[dim], b.sizes[dim]) == 2: |
| 1553 | + # If the array doesn't have coords we can only infer |
| 1554 | + # that it has composite values if the size is at least 2. |
| 1555 | + # Once padded, rechunk the padded array because apply_ufunc |
| 1556 | + # requires core dimensions not to be chunked: |
| 1557 | + if a.sizes[dim] < b.sizes[dim]: |
| 1558 | + a = a.pad({dim: (0, 1)}, constant_values=0) |
| 1559 | + # TODO: Should pad or apply_ufunc handle correct chunking? |
| 1560 | + a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a |
| 1561 | + else: |
| 1562 | + b = b.pad({dim: (0, 1)}, constant_values=0) |
| 1563 | + # TODO: Should pad or apply_ufunc handle correct chunking? |
| 1564 | + b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b |
| 1565 | + else: |
| 1566 | + raise ValueError( |
| 1567 | + f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:" |
| 1568 | + " dimensions without coordinates must have have a length of 2 or 3" |
| 1569 | + ) |
| 1570 | + |
| 1571 | + c = apply_ufunc( |
| 1572 | + np.cross, |
| 1573 | + a, |
| 1574 | + b, |
| 1575 | + input_core_dims=[[dim], [dim]], |
| 1576 | + output_core_dims=[[dim] if a.sizes[dim] == 3 else []], |
| 1577 | + dask="parallelized", |
| 1578 | + output_dtypes=[np.result_type(a, b)], |
| 1579 | + ) |
| 1580 | + c = c.transpose(*all_dims, missing_dims="ignore") |
| 1581 | + |
| 1582 | + return c |
| 1583 | + |
| 1584 | + |
1376 | 1585 | def dot(*arrays, dims=None, **kwargs):
|
1377 | 1586 | """Generalized dot product for xarray objects. Like np.einsum, but
|
1378 | 1587 | provides a simpler interface based on array dimensions.
|
|
0 commit comments