A generic interface for linear algebra backends: code it once, run it on any backend
- Requirements and Installation
- Basic Usage
- List of Types
- List of Methods
- Devices
- Lazy Shapes
- Random Numbers
- Control Flow Cache
Requirements and Installation
pip install backends
Basic Usage
The basic use case for the package is to write code that automatically determines the backend to use depending on the types of its arguments.
Example:
import lab as B
import lab.autograd # Load the AutoGrad extension.
import lab.torch # Load the PyTorch extension.
import lab.tensorflow # Load the TensorFlow extension.
import lab.jax # Load the JAX extension.
def objective(matrix):
    outer_product = B.matmul(matrix, matrix, tr_b=True)
    return B.mean(outer_product)
The AutoGrad, PyTorch, TensorFlow, and JAX extensions are not loaded automatically, so that LAB does not force a dependency on all four frameworks.
An extension can alternatively be loaded via import lab.autograd as B.
Run it with NumPy and AutoGrad:
>>> import autograd.numpy as np
>>> from autograd import grad
>>> objective(B.randn(np.float64, 2, 2))
0.15772589216756833
>>> grad(objective)(B.randn(np.float64, 2, 2))
array([[ 0.23519042, -1.06282928],
       [ 0.23519042, -1.06282928]])
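The two rows of the printed gradient are identical, and this follows from the objective itself: for an n x d matrix M, mean(M M^T) equals (1/n^2) sum_k (sum_i M_ik)^2, so every row of the gradient is 2/n^2 times the vector of column sums. The NumPy-only sketch below checks this derivation against central finite differences; objective_np, analytic_grad, and numerical_grad are illustrative names, not part of LAB.

```python
import numpy as np


def objective_np(matrix):
    # NumPy-only version of `objective`: the mean of M @ M.T.
    return np.mean(matrix @ matrix.T)


def analytic_grad(matrix):
    # mean(M M^T) = (1 / n^2) * sum_k (sum_i M_ik)^2 for an n x d matrix M,
    # so every row of the gradient equals 2 / n^2 times the column sums.
    n = matrix.shape[0]
    return np.broadcast_to(2 / n**2 * matrix.sum(axis=0), matrix.shape)


def numerical_grad(f, x, h=1e-6):
    # Central finite differences, perturbing one entry at a time.
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = h
        g[idx] = (f(x + step) - f(x - step)) / (2 * h)
    return g


rng = np.random.default_rng(0)
m = rng.standard_normal((2, 2))
```

Because the objective is quadratic in M, the finite-difference estimate agrees with the analytic gradient up to rounding error.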
Run it with TensorFlow:
>>> import tensorflow as tf
>>> objective(B.randn(tf.float64, 2, 2))
<tf.Tensor 'Mean:0' shape=() dtype=float64>
Run it with PyTorch:
>>> import torch
>>> objective(B.randn(torch.float64, 2, 2))
tensor(1.9557, dtype=torch.float64)
Run it with JAX:
>>> import jax
>>> import jax.numpy as jnp
>>> jax.jit(objective)(B.randn(jnp.float32, 2, 2))
DeviceArray(0.3109299, dtype=float32)
>>> jax.jit(jax.grad(objective))(B.randn(jnp.float32, 2, 2))
DeviceArray([[ 0.2525182, -1.26065 ],
             [ 0.2525182, -1.26065 ]], dtype=float32)
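What makes objective portable is that every B.* call picks an implementation from the types of its arguments. The sketch below illustrates that idea on a small scale with the standard library's functools.singledispatch, which only inspects the type of the first argument. It is a stand-in for illustration, not how LAB implements dispatch, and exp and shifted_exp are hypothetical names.

```python
import math
from functools import singledispatch

import numpy as np


@singledispatch
def exp(a):
    # Fallback when no implementation is registered for type(a).
    raise NotImplementedError(f"no backend registered for {type(a).__name__}")


@exp.register
def _(a: float):
    # Plain Python floats are handled by the `math` module.
    return math.exp(a)


@exp.register
def _(a: np.ndarray):
    # NumPy arrays are handled by NumPy.
    return np.exp(a)


def shifted_exp(x):
    # Written once: works for every type that `exp` has an implementation for.
    return exp(x) - 1
```

Registering an implementation for another array type would make shifted_exp work for that type without touching shifted_exp itself, which is roughly the role the framework extensions play.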
List of Types
This section lists all available types, which can be used to check types of objects or extend functions.
Int # Integers
Float # Floating-point numbers
Complex # Complex numbers
Bool # Booleans
Number # Numbers
Numeric # Numerical objects, including booleans
DType # Data type
Framework # Anything accepted by supported frameworks
Device # Any device type
NPNumeric
NPDType
NPRandomState
NP # Anything NumPy
AGNumeric
AGDType
AGRandomState
AG # Anything AutoGrad
TFNumeric
TFDType
TFRandomState
TFDevice
TF # Anything TensorFlow
TorchNumeric
TorchDType
TorchDevice
TorchRandomState
Torch # Anything PyTorch
JAXNumeric
JAXDType
JAXDevice
JAXRandomState
JAX # Anything JAX
List of Methods
This section lists all available constants and methods.
- Positional arguments must be given positionally and keyword arguments must be given as keyword arguments. For example, sum(tensor, axis=1) is valid, but sum(tensor, 1) is not.
- The names of arguments are indicative of their function:
  - a, b, and c indicate general tensors.
  - dtype indicates a data type, e.g. np.float32 or tf.float64. For example, rand(np.float32) creates a NumPy random number, whereas rand(tf.float64) creates a TensorFlow random number. Data types are always given as the first argument.
  - shape indicates a shape. The dimensions of a shape are always given as separate arguments to the function. For example, reshape(tensor, 2, 2) is valid, but reshape(tensor, (2, 2)) is not.
  - axis indicates an axis over which the function may perform its action. An axis is always given as a keyword argument.
  - device refers to a device on which a tensor can be placed, which can either be a framework-specific type or a string, e.g. "cpu".
  - ref indicates a reference tensor from which properties, like its shape and data type, will be used. For example, zeros(tensor) creates a tensor full of zeros of the same shape and data type as tensor.
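To make these conventions concrete, here is a small NumPy-only sketch. The functions reduce_sum, reshape, and zeros_like_ref are hypothetical and only mimic the calling conventions described above; they are not LAB's implementations.

```python
import numpy as np


def reduce_sum(a, *, axis=None):
    # The bare `*` makes `axis` keyword-only, so reduce_sum(a, 1) raises TypeError.
    return np.sum(a, axis=axis)


def reshape(a, *shape):
    # Dimensions arrive as separate positional arguments: reshape(a, 2, 2).
    return np.reshape(a, shape)


def zeros_like_ref(ref):
    # The reference tensor supplies both the shape and the data type.
    return np.zeros(ref.shape, dtype=ref.dtype)


x = np.arange(4.0)
m = reshape(x, 2, 2)       # [[0., 1.], [2., 3.]]
s = reduce_sum(m, axis=1)  # [1., 5.]
z = zeros_like_ref(m)      # float64 zeros of shape (2, 2)

rejected = False
try:
    reduce_sum(m, 1)  # A positional axis is rejected.
except TypeError:
    rejected = True
```

Making axis keyword-only in the signature is one straightforward way to enforce the "keyword arguments as keyword arguments" rule at call time.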
See the documentation for more detailed descriptions of each function.
default_dtype # Default data type.
epsilon # Magnitude of diagonal to regularise matrices with.
cholesky_retry_factor # When retrying a failed Cholesky, increase `epsilon` by at most this factor.
nan
pi
log_2_pi
dtype(a)
dtype_float(dtype)
dtype_float(a)
dtype_int(dtype)
dtype_int(a)
promote_dtypes(*dtype)
issubdtype(dtype1, dtype2)
isabstract(a)
jit(f, **kw_args)
isnan(a)
real(a)
imag(a)
device(a)
on_device(device)
on_device(a)
set_global_device(device)
to_active_device(a)
zeros(dtype, *shape)
zeros(*shape)
zeros(ref)
ones(dtype, *shape)
ones(*shape)
ones(ref)
zero(dtype)
zero(*refs)
one(dtype)
one(*refs)
eye(dtype, *shape)
eye(*shape)
eye(ref)
linspace(dtype, a, b, num)
linspace(a, b, num)
range(dtype, start, stop, step)
range(dtype, stop)
range(dtype, start, stop)
range(start, stop, step)
range(start, stop)
range(stop)
cast(dtype, a)
identity(a)
round(a)
floor(a)
ceil(a)
negative(a)
abs(a)
sign(a)
sqrt(a)
exp(a)
log(a)
log1p(a)
sin(a)
arcsin(a)
cos(a)
arccos(a)
tan(a)
arctan(a)
tanh(a)
arctanh(a)
loggamma(a)
logbeta(a)
erf(a)
sigmoid(a)
softplus(a)
relu(a)
add(a, b)
subtract(a, b)
multiply(a, b)
divide(a, b)
power(a, b)
minimum(a, b)
maximum(a, b)
leaky_relu(a, alpha)
softmax(a, axis=None)
min(a, axis=None, squeeze=True)
max(a, axis=None, squeeze=True)
sum(a, axis=None, squeeze=True)
prod(a, axis=None, squeeze=True)
mean(a, axis=None, squeeze=True)
std(a, axis=None, squeeze=True)
logsumexp(a, axis=None, squeeze=True)
all(a, axis=None, squeeze=True)
any(a, axis=None, squeeze=True)
nansum(a, axis=None, squeeze=True)
nanprod(a, axis=None, squeeze=True)
nanmean(a, axis=None, squeeze=True)
nanstd(a, axis=None, squeeze=True)
argmin(a, axis=None)
argmax(a, axis=None)
lt(a, b)
le(a, b)
gt(a, b)
ge(a, b)
eq(a, b)
ne(a, b)
bvn_cdf(a, b, c)
cond(condition, f_true, f_false, *xs)
where(condition, a, b)
scan(f, xs, *init_state)
sort(a, axis=-1, descending=False)
argsort(a, axis=-1, descending=False)
quantile(a, q, axis=None)
to_numpy(a)
jit_to_numpy(a) # Caches results for `B.jit`.
transpose(a, perm=None) (alias: t, T)
matmul(a, b, tr_a=False, tr_b=False) (alias: mm, dot)
einsum(equation, *elements)
trace(a, axis1=0, axis2=1)
kron(a, b)
svd(a, compute_uv=True)
eig(a, compute_eigvecs=True)
solve(a, b)
inv(a)
pinv(a)
det(a)
logdet(a)
expm(a)
logm(a)
cholesky(a) (alias: chol)
cholesky_solve(a, b) (alias: cholsolve)
triangular_solve(a, b, lower_a=True) (alias: trisolve)
toeplitz_solve(a, b, c) (alias: toepsolve)
toeplitz_solve(a, c)
outer(a, b)
reg(a, diag=None, clip=True)
pw_dists2(a, b)
pw_dists2(a)
pw_dists(a, b)
pw_dists(a)
ew_dists2(a, b)
ew_dists2(a)
ew_dists(a, b)
ew_dists(a)
pw_sums2(a, b)
pw_sums2(a)
pw_sums(a, b)
pw_sums(a)
ew_sums2(a, b)
ew_sums2(a)
ew_sums(a, b)
ew_sums(a)
set_random_seed(seed)
create_random_state(dtype, seed=0)
global_random_state(dtype)
global_random_state(a)
set_global_random_state(state)
rand(state, dtype, *shape)
rand(dtype, *shape)
rand(*shape)
rand(state, ref)
rand(ref)
randn(state, dtype, *shape)
randn(dtype, *shape)
randn(*shape)
randn(state, ref)
randn(ref)
randcat(state, p, *shape)
randcat(p, *shape)
choice(state, a, *shape, p=None)
choice(a, *shape, p=None)
randint(state, dtype, *shape, lower=0, upper)
randint(dtype, *shape, lower=0, upper)
randint(*shape, lower=0, upper)
randint(state, ref, lower=0, upper)
randint(ref, lower=0, upper)
randperm(state, dtype, n)
randperm(dtype, n)
randperm(n)
randgamma(state, dtype, *shape, alpha, scale)
randgamma(dtype, *shape, alpha, scale)
randgamma(*shape, alpha, scale)
randgamma(state, ref, *, alpha, scale)
randgamma(ref, *, alpha, scale)
randbeta(state, dtype, *shape, alpha, beta)
randbeta(dtype, *shape, alpha, beta)
randbeta(*shape, alpha, beta)
randbeta(state, ref, *, alpha, beta)
randbeta(ref, *, alpha, beta)
shape(a, *dims)
rank(a)
length(a) (alias: size)
is_scalar(a)
expand_dims(a, axis=0, times=1)
squeeze(a, axis=None)
uprank(a, rank=2)
downrank(a, rank=2, preserve=False)
broadcast_to(a, *shape)
diag(a)
diag_extract(a)
diag_construct(a)
flatten(a)
vec_to_tril(a, offset=0)
tril_to_vec(a, offset=0)
stack(*elements, axis=0)
unstack(a, axis=0, squeeze=True)
reshape(a, *shape)
concat(*elements, axis=0)
concat2d(*rows)
tile(a, *repeats)
take(a, indices_or_mask, axis=0)
submatrix(a, indices_or_mask)
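The pw_* and ew_* functions differ in how rows are paired: pairwise versions compare every row of a with every row of b, whereas elementwise versions compare row i of a only with row i of b, and the suffix 2 denotes squared distances. Below is a NumPy reference sketch of that reading, assuming rows are data points. The function names are hypothetical, and the exact shapes LAB returns (for example, whether the elementwise result keeps a trailing column) are assumptions to check against the documentation.

```python
import numpy as np


def pw_dists2_sketch(a, b=None):
    # Pairwise squared Euclidean distances: entry (i, j) compares row i of a
    # (n x d) with row j of b (m x d), giving an n x m matrix.
    # Assumed: with a single argument, b defaults to a.
    if b is None:
        b = a
    diff = a[:, None, :] - b[None, :, :]
    return np.sum(diff**2, axis=-1)


def pw_dists_sketch(a, b=None):
    # Unsquared version: the square root of the squared distances.
    return np.sqrt(pw_dists2_sketch(a, b))


def ew_dists2_sketch(a, b):
    # Elementwise squared distances: row i of a is compared only with row i
    # of b, so a and b must both be n x d. The n x 1 output is an assumption.
    return np.sum((a - b) ** 2, axis=-1, keepdims=True)


a = np.array([[0.0, 0.0], [1.0, 0.0]])
b = np.array([[0.0, 1.0], [3.0, 4.0]])
pw = pw_dists2_sketch(a, b)  # [[1., 25.], [2., 20.]]
ew = ew_dists2_sketch(a, b)  # [[1.], [20.]]
```

When a and b have the same number of rows, the diagonal of the pairwise result coincides with the elementwise result.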
Devices
You can get the device of a tensor with B.device(a), and you can execute a computation on a device by entering B.on_device(device) as a context:
import lab as B
import lab.tensorflow  # Load the TensorFlow extension.
import tensorflow as tf

with B.on_device("gpu:0"):
    a = B.randn(tf.float32, 2, 2)
    b = B.randn(tf.float32, 2, 2)
    c = a @ b
Within such a context, a tensor that is not on the active device can be moved to the active device with B.to_active_device(a).
You can also globally set the active device with B.set_global_device("gpu:0").
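One way to picture how the active device can be resolved is a stack of devices entered through a context manager, with a global setting as the fallback. The following pure-Python sketch shows that behaviour, including restoring the previous device on exit. use_device, set_default_device, and current_device are hypothetical names, and LAB's actual implementation may differ.

```python
from contextlib import contextmanager

_default_device = "cpu"  # Fallback used outside any context.
_device_stack = []  # Devices entered via `use_device`, innermost last.


def set_default_device(device):
    # Analogous in spirit to setting the device globally.
    global _default_device
    _default_device = device


def current_device():
    # The innermost context wins; otherwise fall back to the global default.
    return _device_stack[-1] if _device_stack else _default_device


@contextmanager
def use_device(device):
    _device_stack.append(device)
    try:
        yield device
    finally:
        # Restore the previous device even if the body raises.
        _device_stack.pop()
```

Nested contexts then behave as expected: leaving an inner context reactivates the device of the enclosing one.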
Lazy Shapes
If a function is evaluated abstractly, then elements of the shape of a tensor, e.g. B.shape(a)[0], may also be tensors, which can break dispatch. By entering the context B.lazy_shapes(), shapes and elements of shapes are wrapped in custom types, which fixes this issue.
import lab as B

with B.lazy_shapes():
    a = B.eye(2)
    print(type(B.shape(a)))
    # <class 'lab.shape.Shape'>
    print(type(B.shape(a)[0]))
    # <class 'lab.shape.Dimension'>
Random Numbers
If you call a random number generator without providing a random state, e.g. B.randn(np.float32, 2), the global random state of the corresponding backend is used.
For JAX, since there is no global random state, LAB provides a JAX global random state accessible through B.jax_global_random_state once lab.jax is loaded.
If you do not want to use a global random state but rather explicitly maintain one, you can create a random state with B.create_random_state and then pass this as the first argument to the random number generators.
The random number generators will then return a tuple containing the updated random state and the random result.
import lab as B
import lab.tensorflow  # Load the TensorFlow extension.
import tensorflow as tf

# Create random state.
state = B.create_random_state(tf.float32, seed=0)
# Generate two random arrays.
state, x = B.randn(state, tf.float32, 2)
state, y = B.randn(state, tf.float32, 2)
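The same state-passing convention can be mimicked with NumPy alone. The sketch below assumes a numpy.random.Generator as the state; the hypothetical helpers create_state and randn_with_state take the state as the first argument and return the updated state together with the sample, so a run can be replayed exactly from its seed.

```python
import numpy as np


def create_state(seed=0):
    # Hypothetical helper: a NumPy Generator plays the role of the random state.
    return np.random.default_rng(seed)


def randn_with_state(state, *shape):
    # Take the state first and return (state, sample), as described above.
    # A NumPy Generator updates itself in place, so the returned state is the
    # same object; with a functional backend such as JAX it would be a new key.
    sample = state.standard_normal(shape)
    return state, sample


state = create_state(seed=0)
state, x = randn_with_state(state, 2)
state, y = randn_with_state(state, 2)

# Replaying from the same seed reproduces the same first draw.
replay = create_state(seed=0)
replay, x_again = randn_with_state(replay, 2)
```

Threading the state explicitly makes the sequence of draws part of the function's inputs and outputs, which is what allows exact replay.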
Control Flow Cache
Coming soon!