linox is a Python package that provides a collection of linear operators for JAX, enabling efficient and flexible linear algebra with lazy evaluation. It is designed as a JAX alternative to probnum.linops, but it is still under development and currently offers fewer and less stable features. Its only dependencies are JAX and plum (for multiple dispatch).
- Lazy Evaluation: All operators support lazy evaluation, allowing for efficient computation of complex linear transformations
- JAX Integration: Built on top of JAX, providing automatic differentiation, parallelization, JIT compilation, and GPU/TPU support
- Composable Operators: Operators can be combined to form complex linear transformations (see the sketch below)
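For example, composing two operators is itself lazy. A minimal sketch, assuming that `@` between two operators returns a lazy product operator (the composability claim above suggests this, but the exact behavior is an assumption):

```python
import jax.numpy as jnp
from linox import Matrix, Diagonal

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
D = Diagonal(jnp.array([1.0, 2.0]))

C = A @ D               # assumed: builds a lazy product operator, no 2x2 product yet
y = C @ jnp.ones((2,))  # the composition is only evaluated on application
```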
- Matrix: General matrix operator
- Identity: Identity matrix operator
- Diagonal: Diagonal matrix operator
- Scalar: Scalar multiple of identity
- Zero: Zero matrix operator
- Ones: Matrix of ones operator
- BlockMatrix: General block matrix operator
- BlockMatrix2x2: 2x2 block matrix operator
- BlockDiagonal: Block diagonal matrix operator
- LowRank: General low-rank operator
- SymmetricLowRank: Symmetric low-rank operator
- IsotropicScalingPlusSymmetricLowRank: Isotropic scaling plus symmetric low-rank operator
- PositiveDiagonalPlusSymmetricLowRank: Positive diagonal plus symmetric low-rank operator
- Kronecker: Kronecker product operator (see the sketch after this list)
- Permutation: Permutation matrix operator
- EigenD: Eigenvalue decomposition operator
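As a quick illustration, structured operators can be built directly from arrays and combined. A minimal sketch; the two-argument Kronecker(A, B) constructor is an assumption, not confirmed by this README:

```python
import jax.numpy as jnp
from linox import Diagonal, Kronecker

D1 = Diagonal(jnp.array([1.0, 2.0]))
D2 = Diagonal(jnp.array([3.0, 4.0]))

K = Kronecker(D1, D2)   # assumed signature: represents the 4x4 operator D1 (x) D2 lazily
y = K @ jnp.ones((4,))  # matvec can exploit the Kronecker structure
```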
- Automatic Differentiation: Compute gradients automatically through operator compositions (see the sketch after this list)
- JIT Compilation: Speed up computations with just-in-time compilation
- Vectorization: Efficient batch processing of linear operations, e.g. via jax.vmap
- GPU/TPU Support: Run computations on accelerators without code changes
- Functional Programming: Pure functions enable better optimization and parallelization
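To illustrate the JAX integration, here is a minimal sketch of differentiating and JIT-compiling through an operator. It assumes that linox operators can be constructed inside transformed functions (i.e., that they work under jax.grad and jax.jit, as the claims above suggest):

```python
import jax
import jax.numpy as jnp
from linox import Diagonal

def loss(theta):
    D = Diagonal(theta)           # build an operator from the parameters
    x = jnp.ones_like(theta)
    return jnp.sum((D @ x) ** 2)  # apply it inside a differentiable function

theta = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(loss)(theta)  # gradients flow through the operator
fast_loss = jax.jit(loss)      # the same function can be JIT-compiled
```

The following quick start combines these pieces: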
import jax
import jax.numpy as jnp
from linox import Matrix, Diagonal, BlockMatrix

# Create operators
A = Matrix(jnp.array([[1, 2], [3, 4]], dtype=jnp.float32))
D = Diagonal(jnp.array([1, 2], dtype=jnp.float32))

# Compose operators into a lazy 4x4 block operator
B = BlockMatrix([[A, D], [D, A]])

# Apply to a vector; the result is computed on demand and
# B is never materialized as a dense matrix
x = jnp.ones((4,), dtype=jnp.float32)
y = B @ x

# Parallelize over a batch of vectors
x_batched = jnp.ones((10, 4), dtype=jnp.float32)
y_batched = jax.vmap(B)(x_batched)
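Here y has shape (4,) and y_batched has shape (10, 4); in both cases the dense 4x4 block matrix represented by B is never formed.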