Skip to content

Add xtensor docs #1504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions doc/conf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import inspect
import sys

import pytensor
from pathlib import Path

Expand All @@ -12,6 +13,7 @@
"sphinx.ext.autodoc",
"sphinx.ext.todo",
"sphinx.ext.doctest",
"sphinx_copybutton",
"sphinx.ext.napoleon",
"sphinx.ext.linkcode",
"sphinx.ext.mathjax",
Expand Down Expand Up @@ -86,8 +88,7 @@

# List of directories, relative to source directories, that shouldn't be
# searched for source files.
exclude_dirs = ["images", "scripts", "sandbox"]
exclude_patterns = ['page_footer.md', '**/*.myst.md']
exclude_patterns = ["README.md", "images/*", "page_footer.md", "**/*.myst.md"]

# The reST default role (used for this markup: `text`) to use for all
# documents.
Expand Down Expand Up @@ -235,24 +236,41 @@
# Resolve function
# This function is used to populate the (source) links in the API
def linkcode_resolve(domain, info):
def find_source():
def find_obj() -> object:
# try to find the file and line number, based on code from numpy:
# https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
obj = sys.modules[info["module"]]
for part in info["fullname"].split("."):
obj = getattr(obj, part)
return obj

def find_source(obj):
fn = Path(inspect.getsourcefile(obj))
fn = fn.relative_to(Path(__file__).parent)
fn = fn.relative_to(Path(pytensor.__file__).parent)
source, lineno = inspect.getsourcelines(obj)
return fn, lineno, lineno + len(source) - 1

def fallback_source():
return info["module"].replace(".", "/") + ".py"

if domain != "py" or not info["module"]:
return None

try:
filename = "pytensor/%s#L%d-L%d" % find_source()
obj = find_obj()
except Exception:
filename = info["module"].replace(".", "/") + ".py"
filename = fallback_source()
else:
try:
filename = "pytensor/%s#L%d-L%d" % find_source(obj)
except Exception:
# warnings.warn(f"Could not find source code for {domain}:{info}")
try:
filename = obj.__module__.replace(".", "/") + ".py"
except AttributeError:
# Some objects do not have a __module__ attribute (?)
filename = fallback_source()

import subprocess

tag = subprocess.Popen(
Expand Down
3 changes: 2 additions & 1 deletion doc/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ dependencies:
- mock
- pillow
- pymc-sphinx-theme
- sphinx-copybutton
- sphinx-design
- sphinx-sitemap
- pygments
- pydot
- ipython
Expand All @@ -23,5 +25,4 @@ dependencies:
- ablog
- pip
- pip:
- sphinx_sitemap
- -e ..
4 changes: 1 addition & 3 deletions doc/library/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@ Modules
d3viz/index
graph/index
gradient
misc/pkl_utils
printing
scalar/index
scan
sparse/index
sparse/sandbox
tensor/index
typed_list
xtensor/index

.. module:: pytensor
:platform: Unix, Windows
Expand Down
101 changes: 101 additions & 0 deletions doc/library/xtensor/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
(libdoc_xtensor)=
# `xtensor` -- XTensor operations

This module implements as abstraction layer on regular tensor operations, that behaves like Xarray.

A new type {class}`pytensor.xtensor.type.XTensorType`, generalizes the {class}`pytensor.tensor.TensorType`
with the addition of a `dims` attribute, that labels the dimensions of the tensor.

Variables of XTensorType (i.e., {class}`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart
to xarray DataArray objects.

The module implements several PyTensor operations {class}`pytensor.xtensor.basic.XOp`s, whose signature mimics that of
xarray (and xarray_einstats) DataArray operations. These operations, unlike most regular PyTensor operations, cannot
be directly evaluated, but require a rewrite (lowering) into a regular tensor graph that can itself be evaluated as usual.

Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
If the existing XOps can be composed to produce the desired result, then we can use them directly.

## Coordinates
For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.

## Example


```{testcode}

import pytensor.tensor as pt
import pytensor.xtensor as ptx

a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))

ax = ptx.as_xtensor(a, dims=["x"])
bx = ptx.as_xtensor(b, dims=["y"])

zx = ax + bx
assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))

z = zx.values
z.dprint()
```


```{testoutput}

TensorFromXTensor [id A]
└─ XElemwise{scalar_op=Add()} [id B]
├─ XTensorFromTensor{dims=('x',)} [id C]
│ └─ a [id D]
└─ XTensorFromTensor{dims=('y',)} [id E]
└─ b [id F]
```

Once we compile the graph, no XOps are left.

```{testcode}

import pytensor

with pytensor.config.change_flags(optimizer_verbose=True):
fn = pytensor.function([a, b], z)

```

```{testoutput}

rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)

```

```{testcode}

fn.dprint()
```

```{testoutput}

Add [id A] 2
├─ ExpandDims{axis=1} [id B] 1
│ └─ a [id C]
└─ ExpandDims{axis=0} [id D] 0
└─ b [id E]
```


## Index

:::{toctree}
:maxdepth: 1

module_functions
math
linalg
random
type
:::
7 changes: 7 additions & 0 deletions doc/library/xtensor/linalg.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(libdoc_xtensor_linalg)=
# `xtensor.linalg` -- Linear algebra operations

```{eval-rst}
.. automodule:: pytensor.xtensor.linalg
:members:
```
8 changes: 8 additions & 0 deletions doc/library/xtensor/math.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
(libdoc_xtensor_math)=
# `xtensor.math` Mathematical operations

```{eval-rst}
.. automodule:: pytensor.xtensor.math
:members:
:exclude-members: XDot, dot
```
7 changes: 7 additions & 0 deletions doc/library/xtensor/module_functions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(libdoc_xtensor_module_function)=
# `xtensor` -- Module level operations

```{eval-rst}
.. automodule:: pytensor.xtensor
:members: broadcast, concat, dot, full_like, ones_like, zeros_like
```
7 changes: 7 additions & 0 deletions doc/library/xtensor/random.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(libdoc_xtensor_random)=
# `xtensor.random` Random number generator operations

```{eval-rst}
.. automodule:: pytensor.xtensor.random
:members:
```
21 changes: 21 additions & 0 deletions doc/library/xtensor/type.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
(libdoc_xtenor_type)=

# `xtensor.type` -- Types and Variables

## XTensorVariable creation functions

```{eval-rst}
.. automodule:: pytensor.xtensor.type
:members: xtensor, xtensor_constant, as_xtensor

```

## XTensor Type and Variable classes

```{eval-rst}
.. automodule:: pytensor.xtensor.type
:noindex:
:members: XTensorType, XTensorVariable, XTensorConstant
```


2 changes: 1 addition & 1 deletion pytensor/xtensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, random
from pytensor.xtensor import linalg, math, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
from pytensor.xtensor.type import (
Expand Down
42 changes: 40 additions & 2 deletions pytensor/xtensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,31 @@ def cholesky(
lower: bool = True,
*,
check_finite: bool = False,
overwrite_a: bool = False,
on_error: Literal["raise", "nan"] = "raise",
dims: Sequence[str],
):
"""Compute the Cholesky decomposition of an XTensorVariable.

Parameters
----------
x : XTensorVariable
The input variable to decompose.
lower : bool, optional
Whether to return the lower triangular matrix. Default is True.
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
on_error : {'raise', 'nan'}, optional
What to do if the input is not positive definite. If 'raise', an error is raised.
If 'nan', the output will contain NaNs. Default is 'raise'.
dims : Sequence[str]
The two core dimensions of the input variable, over which the Cholesky decomposition is computed.
"""
if len(dims) != 2:
raise ValueError(f"Cholesky needs two dims, got {len(dims)}")

core_op = Cholesky(
lower=lower,
check_finite=check_finite,
overwrite_a=overwrite_a,
on_error=on_error,
)
core_dims = (
Expand All @@ -40,6 +54,30 @@ def solve(
lower: bool = False,
check_finite: bool = False,
):
"""Solve a system of linear equations using XTensorVariables.

Parameters
----------
a : XTensorVariable
The left hand-side xtensor.
b : XTensorVariable
The right-hand side xtensor.
dims : Sequence[str]
The core dimensions over which to solve the linear equations.
If length is 2, we are solving a matrix-vector equation,
and the two dimensions should be present in `a`, but only one in `b`.
If length is 3, we are solving a matrix-matrix equation,
and two dimensions should be present in `a`, two in `b`, and only one should be shared.
In both cases the shared dimension will not appear in the output.
assume_a : str, optional
The type of matrix `a` is assumed to be. Default is 'gen' (general).
Options are ["gen", "sym", "her", "pos", "tridiagonal", "banded"].
Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"].
lower : bool, optional
Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos".
check_finite : bool, optional
Whether to check that the input is finite. Default is False.
"""
a, b = as_xtensor(a), as_xtensor(b)
input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]]
output_core_dims: tuple[tuple[str] | tuple[str, str]]
Expand Down
Loading