Description
I'd like to add a `Matrix` class to complement `Vector`.
A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?
If we only need to support a single "tree axis", then most `Matrix` operations can be implemented essentially by calling `vmap` on a `Vector`, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.
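To make the single "tree axis" idea concrete, here is a rough sketch in plain JAX. It is not a proposed API: `tree_dot`, `stack_trees`, `rows_dot`, and `combine_rows` are hypothetical helper names, and the sketch assumes the non-tree axis is stored as a leading axis on every leaf.

```python
import jax
import jax.numpy as jnp


def tree_dot(x, y):
    """Inner product of two pytrees with identical structure."""
    partial = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), x, y)
    return sum(jax.tree_util.tree_leaves(partial))


def stack_trees(trees):
    """Stack k pytrees into one pytree whose leaves gain a leading axis of size k."""
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees)


def rows_dot(stacked, v):
    """Hypothetical 'matrix @ vector': dot every stacked row with the pytree v."""
    return jax.vmap(tree_dot, in_axes=(0, None))(stacked, v)


def combine_rows(stacked, coeffs):
    """Hypothetical 'matrix.T @ coeffs': linear combination of the rows, as a pytree."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.tensordot(coeffs, leaf, axes=1), stacked
    )


# Two "state vectors" with the same tree structure, as L-BFGS or GMRES might keep.
v1 = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
v2 = {"a": jnp.array([0.0, 1.0]), "b": jnp.array(-1.0)}
history = stack_trees([v1, v2])

coeffs = rows_dot(history, v1)          # ordinary array of shape (2,)
update = combine_rows(history, coeffs)  # pytree with the same structure as v1
```

Both operations only ever see one tree definition, which is the property that keeps this version simple.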
In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of `jax.jacobian` on a pytree -> pytree function. Here the implementation would need to be more complex to keep track of the separate tree definitions for inputs/outputs, similar to my first attempt at implementing a tree vectorizing transformation: jax-ml/jax#3263.
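As a small illustration of why two tree definitions are involved, the following sketch calls `jax.jacobian` on a toy pytree -> pytree function. The function `f` and its keys are made up for this example; the point is only that the result is nested by output structure and then by input structure.

```python
import jax
import jax.numpy as jnp


def f(params):
    # Toy pytree -> pytree function: inputs {"w", "b"}, outputs {"y", "s"}.
    return {"y": params["w"] * params["b"], "s": jnp.sum(params["w"])}


params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(2.0)}

# The result is indexed as jac[output_key][input_key].
jac = jax.jacobian(f)(params)

# One block per (output leaf, input leaf) pair, each with its own shape.
block_shapes = {
    (out_key, in_key): jac[out_key][in_key].shape
    for out_key in jac
    for in_key in jac[out_key]
}
```

A `Matrix` covering this case would have to flatten and unflatten along both the output tree and the input tree, rather than along a single one.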
My inclination is to only implement the "single tree-axis" version of `Matrix`, with the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory. On the other hand, it does preclude the interesting use case of using tree-math to implement `jax.jacobian` (and variations).