Painless optimisation of constrained variables in AutoGrad, TensorFlow, PyTorch, and JAX
See the instructions here. Then simply
pip install varz
from varz import Vars
To begin with, create a variable container of the right data type.
For use with AutoGrad, use a `np.*` data type;
for use with PyTorch, use a `torch.*` data type;
for use with TensorFlow, use a `tf.*` data type;
and for use with JAX, use a `jnp.*` data type.
In this example we'll use AutoGrad.
>>> vs = Vars(np.float64)
Now a variable can be created by requesting it, giving it an initial value and a name.
>>> vs.unbounded(np.random.randn(2, 2), name="x")
array([[ 1.04404354, -1.98478763],
[ 1.14176728, -3.2915562 ]])
If the same variable is created again, because a variable with the name `x` already exists, the existing variable will be returned, even if you again pass it an initial value.
>>> vs.unbounded(np.random.randn(2, 2), name="x")
array([[ 1.04404354, -1.98478763],
[ 1.14176728, -3.2915562 ]])
>>> vs.unbounded(name="x")
array([[ 1.04404354, -1.98478763],
[ 1.14176728, -3.2915562 ]])
Alternatively, indexing syntax may be used to get the existing variable `x`.
This asserts that a variable with the name `x` already exists and will throw a `KeyError` otherwise.
>>> vs["x"]
array([[ 1.04404354, -1.98478763],
[ 1.14176728, -3.2915562 ]])
>>> vs["y"]
KeyError: 'y'
The value of `x` can be changed by assigning it a different value.
>>> vs.assign("x", np.random.randn(2, 2))
array([[ 1.43477728, 0.51006941],
[-0.74686452, -1.05285767]])
By default, assignment is non-differentiable and overwrites data.
The variable can be deleted by passing its name to `vs.delete`:
>>> vs.delete("x")
>>> vs["x"]
KeyError: 'x'
When a variable is first created, you can set the keyword argument `visible` to `False` if you want to make the variable invisible to the variable-aggregating operations `vs.get_latent_vars` and `vs.get_latent_vector`.
These variable-aggregating operations are used in optimisers to get the intended collection of variables to optimise.
Therefore, setting `visible` to `False` will prevent a variable from being optimised.
Finally, a variable container can be copied with `vs.copy()`.
Copies are lightweight and share their variables with the originals.
As a consequence, however, assignment in a copy will also mutate the original.
Differentiable assignment, however, will not.
Variables may be organised by naming them hierarchically using `.`s.
For example, you could name variables `group1.bar`, `group1.foo`, and `group2.bar`.
This is helpful for extracting collections of variables, where wildcards may be used to match names.
For example, `*.bar` would match `group1.bar` and `group2.bar`, and `group1.*` would match `group1.bar` and `group1.foo`.
See also here.
The names of all variables can be obtained with `Vars.names`, and variables can be printed with `Vars.print`.
Example:
>>> vs = Vars(np.float64)
>>> vs.unbounded(1, name="x1")
array(1.)
>>> vs.unbounded(2, name="x2")
array(2.)
>>> vs.unbounded(3, name="y")
array(3.)
>>> vs.names
['x1', 'x2', 'y']
>>> vs.print()
x1: 1.0
x2: 2.0
y: 3.0
- Unbounded variables: A variable that is unbounded can be created using `Vars.unbounded` or `Vars.ubnd`.

      >>> vs.ubnd(name="normal_variable")
      0.016925610008314832

- Positive variables: A variable that is constrained to be positive can be created using `Vars.positive` or `Vars.pos`.

      >>> vs.pos(name="positive_variable")
      0.016925610008314832

- Bounded variables: A variable that is constrained to be bounded can be created using `Vars.bounded` or `Vars.bnd`.

      >>> vs.bnd(name="bounded_variable", lower=1, upper=2)
      1.646772663807718

- Lower-triangular matrix: A matrix variable that is constrained to be lower triangular can be created using `Vars.lower_triangular` or `Vars.tril`. Either an initialisation or a shape of a square matrix must be given.

      >>> vs.tril(shape=(2, 2), name="lower_triangular")
      array([[ 2.64204459,  0.        ],
             [-0.14055559, -1.91298679]])

- Positive-definite matrix: A matrix variable that is constrained to be positive definite can be created using `Vars.positive_definite` or `Vars.pd`. Either an initialisation or a shape of a square matrix must be given.

      >>> vs.pd(shape=(2, 2), name="positive_definite")
      array([[ 1.64097496, -0.52302151],
             [-0.52302151,  0.32628302]])

- Orthogonal matrix: A matrix variable that is constrained to be orthogonal can be created using `Vars.orthogonal` or `Vars.orth`. Either an initialisation or a shape of a square matrix must be given.

      >>> vs.orth(shape=(2, 2), name="orthogonal")
      array([[ 0.31290403, -0.94978475],
             [ 0.94978475,  0.31290403]])

These constrained variables are created by transforming some latent
unconstrained representation to the desired constrained space.
The latent variables can be obtained using `Vars.get_latent_vars`.
>>> vs.get_latent_vars("positive_variable", "bounded_variable")
[array(-4.07892742), array(-0.604883)]
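The latent values printed above can be checked against the constrained values shown earlier. The sketch below is inferred purely from those numbers and is not taken from the varz source: it assumes an exponential map for positive variables and a logistic map for bounded ones. The function names `positive` and `bounded` are illustrative only.

```python
import math


def positive(latent):
    # Assumed transform: the exponential maps any real number to (0, inf).
    return math.exp(latent)


def bounded(latent, lower, upper):
    # Assumed transform: a logistic map squashes any real number
    # into the open interval (lower, upper).
    return lower + (upper - lower) / (1 + math.exp(latent))


# Latent values as printed by `vs.get_latent_vars` in this document.
positive_value = positive(-4.07892742)
bounded_value = bounded(-0.604883, lower=1, upper=2)

print(positive_value)  # Close to the printed value of "positive_variable".
print(bounded_value)   # Close to the printed value of "bounded_variable".
```

Because the optimiser only ever sees the latent value, any real number it proposes maps back to a value that satisfies the constraint.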
To illustrate the use of wildcards, the following is equivalent:
>>> vs.get_latent_vars("*_variable")
[array(-4.07892742), array(-0.604883)]
Variables can be excluded by prepending a dash:
>>> vs.get_latent_vars("*_variable", "-bounded_*")
[array(-4.07892742)]
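The selection rules described here (wildcards to include names, a leading dash to exclude them) can be mimicked with the standard library. This is only a sketch of the behaviour as documented above, using `fnmatch` as a stand-in; how varz matches names internally is not assumed. The helper `select` and the list of names are hypothetical.

```python
from fnmatch import fnmatchcase


def select(names, *patterns):
    """Keep names that match an include pattern and no exclude pattern."""
    includes = [p for p in patterns if not p.startswith("-")]
    excludes = [p[1:] for p in patterns if p.startswith("-")]
    return [
        name
        for name in names
        if any(fnmatchcase(name, p) for p in includes)
        and not any(fnmatchcase(name, p) for p in excludes)
    ]


names = [
    "positive_variable",
    "bounded_variable",
    "group1.bar",
    "group1.foo",
    "group2.bar",
]

print(select(names, "*_variable"))                # Both "*_variable" names.
print(select(names, "*_variable", "-bounded_*"))  # Only "positive_variable".
print(select(names, "*.bar"))                     # "group1.bar", "group2.bar".
print(select(names, "group1.*"))                  # "group1.bar", "group1.foo".
```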
To parametrise functions, a common pattern is the following:
def objective(vs):
x = vs.unbounded(5, name="x")
y = vs.unbounded(10, name="y")
return (x * y - 5) ** 2 + x ** 2
The names for `x` and `y` are necessary, because otherwise new variables will be created and initialised every time `objective` is run.
Varz offers two ways to avoid specifying a name for every variable: sequential and parametrised specification.
Sequential specification can be used if, upon execution of `objective`, variables are always obtained in the same order.
This means that variables can be identified with their position in this order and hence be named accordingly.
To use sequential specification, decorate the function with `sequential`.
Example:
from varz import sequential
@sequential
def objective(vs):
x = vs.unbounded(5) # Initialise to 5.
y = vs.unbounded() # Initialise randomly.
return (x * y - 5) ** 2 + x ** 2
>>> vs = Vars(np.float64)
>>> objective(vs)
68.65047879833773
>>> objective(vs) # Running the objective again reuses the same variables.
68.65047879833773
>>> vs.names
['var0', 'var1']
>>> vs.print()
var0: 5.0 # This is `x`.
var1: -0.3214 # This is `y`.
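The idea of naming variables by their position can be illustrated without varz. The following is a toy sketch, not the library's implementation: `ToyVars`, `_Numbered`, and `sequential_sketch` are hypothetical names, and the container only stores plain floats.

```python
import functools
import random


class ToyVars:
    """Hypothetical stand-in for a variable container: named floats."""

    def __init__(self):
        self.values = {}

    def unbounded(self, init=None, name=None):
        # Create the variable on first request; afterwards return the stored value.
        if name not in self.values:
            self.values[name] = random.gauss(0.0, 1.0) if init is None else float(init)
        return self.values[name]


class _Numbered:
    """Wraps a container and names each request by its position in the call order."""

    def __init__(self, vs):
        self.vs = vs
        self.count = 0

    def unbounded(self, init=None):
        name = f"var{self.count}"
        self.count += 1
        return self.vs.unbounded(init, name=name)


def sequential_sketch(f):
    @functools.wraps(f)
    def wrapped(vs, *args, **kwargs):
        # A fresh counter per call, so the n-th request always maps to `var<n>`.
        return f(_Numbered(vs), *args, **kwargs)

    return wrapped


@sequential_sketch
def objective(vs):
    x = vs.unbounded(5)  # Becomes "var0".
    y = vs.unbounded()   # Becomes "var1".
    return (x * y - 5) ** 2 + x ** 2


vs = ToyVars()
first = objective(vs)
second = objective(vs)  # Reuses "var0" and "var1" rather than creating new ones.
```

The counter restarts on every call, which is why the order of requests must not change between runs.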
Sequential specification still suffers from boilerplate code like `x = vs.unbounded(5)` and `y = vs.unbounded()`.
This is the problem that parametrised specification addresses, which allows you to specify variables as arguments to your function.
Import `from varz.spec import parametrised`.
To indicate that an argument of the function is a variable, as opposed to a
regular argument, the argument's type hint must be set accordingly, as follows:
- Unbounded variables:

      @parametrised
      def f(vs, x: Unbounded):
          ...

- Positive variables:

      @parametrised
      def f(vs, x: Positive):
          ...

- Bounded variables: The following two specifications are possible. The former uses the default bounds and the latter uses specified bounds.

      @parametrised
      def f(vs, x: Bounded):
          ...

      @parametrised
      def f(vs, x: Bounded(lower=1, upper=10)):
          ...

- Lower-triangular variables:

      @parametrised
      def f(vs, x: LowerTriangular(shape=(2, 2))):
          ...

- Positive-definite variables:

      @parametrised
      def f(vs, x: PositiveDefinite(shape=(2, 2))):
          ...

- Orthogonal variables:

      @parametrised
      def f(vs, x: Orthogonal(shape=(2, 2))):
          ...

As can be seen from the above, the variable container must also be an argument of the function, because that is where the variables will be obtained from. A variable can be given an initial value in the way you would expect:
@parametrised
def f(vs, x: Unbounded = 5):
...
Variable arguments and regular arguments can be mixed.
If `f` is called, variable arguments must not be specified, because they will be obtained automatically.
Regular arguments, however, must be specified.
To use parametrised specification, decorate the function with `parametrised`.
Example:
from varz import parametrised, Unbounded, Bounded
@parametrised
def objective(vs, x: Unbounded, y: Bounded(lower=1, upper=3) = 2, option=None):
print("Option:", option)
return (x * y - 5) ** 2 + x ** 2
>>> vs = Vars(np.float64)
>>> objective(vs)
Option: None
9.757481795615316
>>> objective(vs, "other")
Option: other
9.757481795615316
>>> objective(vs, option="other")
Option: other
9.757481795615316
>>> objective(vs, x=5) # This is not valid, because `x` will be obtained automatically from `vs`.
ValueError: 1 keyword argument(s) not parsed: x.
>>> vs.print()
x: 1.025
y: 2.0
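To make the mechanism behind type-hinted variable arguments more concrete, here is a toy sketch built on `inspect.signature`. It is not varz's implementation: `ToyVars`, the marker class `Unbounded`, and `parametrised_sketch` are hypothetical, and only unbounded scalar variables are handled.

```python
import functools
import inspect
import random


class Unbounded:
    """Hypothetical marker type: annotates arguments that are variables."""


class ToyVars:
    """Hypothetical container holding named floats."""

    def __init__(self):
        self.values = {}

    def unbounded(self, init=None, name=None):
        if name not in self.values:
            self.values[name] = random.gauss(0.0, 1.0) if init is None else float(init)
        return self.values[name]


def parametrised_sketch(f):
    params = list(inspect.signature(f).parameters.values())[1:]  # Skip `vs`.
    variable = [p for p in params if p.annotation is Unbounded]
    regular = [p.name for p in params if p.annotation is not Unbounded]

    @functools.wraps(f)
    def wrapped(vs, *args, **kwargs):
        clashes = [p.name for p in variable if p.name in kwargs]
        if clashes:
            raise ValueError(f"variable argument(s) given explicitly: {clashes}")
        # Positional arguments fill the regular parameters in order.
        kwargs.update(zip(regular, args))
        # Variable arguments are fetched from the container by parameter name;
        # a default value in the signature acts as the initial value.
        for p in variable:
            init = None if p.default is inspect.Parameter.empty else p.default
            kwargs[p.name] = vs.unbounded(init, name=p.name)
        return f(vs, **kwargs)

    return wrapped


@parametrised_sketch
def objective(vs, x: Unbounded, y: Unbounded = 2, option=None):
    return (x * y - 5) ** 2 + x ** 2, option


vs = ToyVars()
value, option = objective(vs, "other")  # "other" binds to `option`, not to `x`.
```

Reading the annotations once, at decoration time, keeps the per-call overhead to a few dictionary lookups.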
Namespaces can be used to group all variables in a function together.
Example:
from varz import namespace
@namespace("test")
def objective(vs):
x = vs.unbounded(5, name="x")
y = vs.unbounded(name="y")
return x + y
>>> vs = Vars(np.float64)
>>> objective(vs)
6.12448906632577
>>> vs.names
['test.x', 'test.y']
>>> vs.print()
test.x: 5.0
test.y: 1.124
You can combine namespace with other specification methods:
from varz import namespace
@namespace("test")
@sequential
def objective(vs):
x = vs.unbounded(5)
y = vs.unbounded()
return x + y
>>> vs = Vars(np.float64)
>>> objective(vs)
4.812730329303665
>>> vs.names
['test.var0', 'test.var1']
>>> vs.print()
test.var0: 5.0
test.var1: -0.1873
For any variable container `vs`, `vs.struct` gives an object which you can treat like a nested struct, list, or dictionary to automatically generate variable names.
For example, `vs.struct.model["a"].variance.positive()` would be equivalent to `vs.positive(name="model[a].variance")`.
After variables have been defined in this way, they can also be extracted via `vs.struct`: `vs.struct.model["a"].variance()` would be equivalent to `vs["model[a].variance"]`.
Example:
def objective(vs):
params = vs.struct
x = params.x.unbounded()
y = params.y.unbounded()
for model_params, model in zip(params.models, [object(), object(), object()]):
model_params.specific_parameter1.positive()
model_params.specific_parameter2.positive()
return x + y
>>> vs = Vars(np.float64)
>>> objective(vs)
-0.08322955725015702
>>> vs.names
['x',
'y',
'models[0].specific_parameter1',
'models[0].specific_parameter2',
'models[1].specific_parameter1',
'models[1].specific_parameter2',
'models[2].specific_parameter1',
'models[2].specific_parameter2']
>>> vs.print()
x: -0.8963
y: 0.8131
models[0].specific_parameter1: 0.01855
models[0].specific_parameter2: 0.6644
models[1].specific_parameter1: 0.3542
models[1].specific_parameter2: 0.3642
models[2].specific_parameter1: 0.5807
models[2].specific_parameter2: 0.5977
>>> vs.struct.models[0].specific_parameter1()
0.018551827512328086
>>> vs.struct.models[0].specific_parameter2()
0.6643533007198247
There are a few methods available for convenient manipulation of the variable struct.
In the following, let `params = vs.struct`.

- Go up a directory: `params.a.b.c.up()` goes up one directory and gives `params.a.b`. If you want to be sure about which directory you are going up, you can pass the name of the directory you want to go up as an argument: `params.a.b.c.up("c")` will give the intended result, but `params.a.b.c.up("b")` will result in an assertion error.
- Get all variables in a path: `params.a.all()` gives the regex `a.*`.
- Check if a variable exists: `bool(params.a)` gives `True` if `a` is a defined variable and `False` otherwise.
- Assign a value to a variable: `params.a.assign(1)` assigns `1` to `a`.
- Delete a variable: `params.a.delete()` deletes `a`.
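
The naming rule of the variable struct and two of the helpers listed above (`up` and `all`) can be mimicked with a small path builder. This is a toy sketch, not varz's implementation: `PathSketch` and its `name` method are hypothetical, and it only builds names without touching any variables.

```python
class PathSketch:
    """Hypothetical name builder: attributes add `.name`, indices add `[key]`."""

    def __init__(self, parts=()):
        self._parts = tuple(parts)

    def __getattr__(self, name):
        # Only called for attributes that do not exist, i.e. path components.
        if name.startswith("_"):
            raise AttributeError(name)
        return PathSketch(self._parts + ("." + name,))

    def __getitem__(self, key):
        return PathSketch(self._parts + (f"[{key}]",))

    def up(self, expected=None):
        # Optionally check which component is being left before going up.
        last = self._parts[-1].lstrip(".")
        if expected is not None:
            assert last == expected, f"expected {expected!r}, but at {last!r}"
        return PathSketch(self._parts[:-1])

    def all(self):
        return self.name() + ".*"

    def name(self):
        return "".join(self._parts).lstrip(".")


params = PathSketch()
print(params.model["a"].variance.name())            # model[a].variance
print(params.models[0].specific_parameter1.name())  # models[0].specific_parameter1
print(params.a.b.c.up("c").name())                  # a.b
print(params.a.all())                               # a.*
```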
The following optimisers are available:
varz.{autograd,tensorflow,torch,jax}.minimise_l_bfgs_b (L-BFGS-B)
varz.{autograd,tensorflow,torch,jax}.minimise_adam (ADAM)
The L-BFGS-B algorithm is recommended for deterministic objectives and ADAM is recommended for stochastic objectives.
See the examples for an illustration of how these optimisers can be used. Some commonly used keyword arguments are as follows:
| Keyword Argument | Description |
| --- | --- |
| `iters` | Number of iterations |
| `trace` | Show progress |
| `jit` | Use a JIT to compile the gradient |
See the API for a detailed description of the keyword arguments that these optimisers accept.
All the variables held by a container can be detached from the current computation graph with `Vars.detach`.
To make a copy of the container with detached versions of the variables, use `Vars.copy` with `detach=True` instead.
Whether variables require gradients can be configured with `Vars.requires_grad`.
By default, no variable requires a gradient.
It may be desirable to get the latent representations of a collection of
variables as a single vector, e.g. when feeding them to an optimiser.
This can be achieved with `Vars.get_latent_vector`.
>>> vs.get_latent_vector("x", "*_variable")
array([0.12500578, -0.21510423, -0.61336039, 1.23074066, -4.07892742,
-0.604883])
Similarly, to update the latent representation of a collection of variables, `Vars.set_latent_vector` can be used.
>>> vs.set_latent_vector(np.ones(6), "x", "*_variable")
[array([[1., 1.],
[1., 1.]]), array(1.), array(1.)]
>>> vs.get_latent_vector("x", "*_variable")
array([1., 1., 1., 1., 1., 1.])
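Packing several latent variables into one vector and unpacking them again amounts to flattening and slicing by size. The sketch below shows that bookkeeping with plain Python lists so that it runs without any array library. The helper names, and the representation of each variable as a `(shape, flat_values)` pair, are assumptions for illustration and do not reflect how varz stores variables.

```python
from math import prod


def get_latent_vector(latents, names):
    """Concatenate the flattened latent values of the named variables."""
    vector = []
    for name in names:
        _, flat = latents[name]
        vector.extend(flat)
    return vector


def set_latent_vector(latents, vector, names):
    """Slice `vector` back into the named variables, keeping their shapes."""
    start = 0
    for name in names:
        shape, _ = latents[name]
        size = prod(shape)  # The empty shape `()` has size 1.
        latents[name] = (shape, list(vector[start:start + size]))
        start += size
    assert start == len(vector), "vector length does not match the variables"


# Latent values as printed in this document.
latents = {
    "x": ((2, 2), [0.12500578, -0.21510423, -0.61336039, 1.23074066]),
    "positive_variable": ((), [-4.07892742]),
    "bounded_variable": ((), [-0.604883]),
}
names = ["x", "positive_variable", "bounded_variable"]

before = get_latent_vector(latents, names)
set_latent_vector(latents, [1.0] * 6, names)
after = get_latent_vector(latents, names)
```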
By default, `Vars.set_latent_vector` will overwrite the variables, just like `Vars.assign`.
This has as an unfortunate consequence that you cannot differentiate with respect to the assigned values.
To be able to differentiate with respect to the assigned values, set the keyword `differentiable=True` in the call to `Vars.set_latent_vector`.
Unlike regular assignment, if the variable container is a copy of some original, differentiable assignment will not mutate the variables in the original.
The keyword argument `source` can be set to a tensor from which the latent variables will be obtained.
Example:
>>> vs = Vars(np.float32, source=np.array([1, 2, 3, 4, 5]))
>>> vs.unbounded()
array(1., dtype=float32)
>>> vs.unbounded(shape=(3,))
array([2., 3., 4.], dtype=float32)
>>> vs.pos()
148.41316
>>> np.exp(5).astype(np.float32)
148.41316
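The example above suggests that each new variable consumes the next unused entries of `source` as its latent representation. A toy sketch of that bookkeeping follows. `SourceSketch` is hypothetical, handles only scalars and one-dimensional shapes, and assumes the exponential transform for positive variables that the `np.exp(5)` comparison above points to.

```python
import math


class SourceSketch:
    """Hypothetical container that reads latent values from a flat source."""

    def __init__(self, source):
        self.source = [float(value) for value in source]
        self.index = 0

    def _take(self, size):
        # Hand out the next `size` unused entries of the source.
        chunk = self.source[self.index:self.index + size]
        self.index += size
        return chunk

    def unbounded(self, shape=()):
        chunk = self._take(math.prod(shape))
        return chunk if shape else chunk[0]

    def pos(self):
        # Positive variable: exponential of the latent value.
        return math.exp(self._take(1)[0])


vs = SourceSketch([1, 2, 3, 4, 5])
a = vs.unbounded()            # Consumes the first entry.
b = vs.unbounded(shape=(3,))  # Consumes the next three entries.
c = vs.pos()                  # Consumes the last entry and exponentiates it.
```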
To create and optimise variables on a GPU, set the active device to a GPU.
The easiest way of doing this is to `import lab as B` and call `B.set_global_device("gpu:0")`.
import autograd.numpy as np
from varz.autograd import Vars, minimise_l_bfgs_b
target = 5.0
def objective(vs):
# Get a variable named "x", which must be positive, initialised to 10.
x = vs.pos(10.0, name="x")
return (x ** 0.5 - target) ** 2
>>> vs = Vars(np.float64)
>>> minimise_l_bfgs_b(objective, vs)
3.17785950743424e-19 # Final objective function value.
>>> vs['x'] - target ** 2
-5.637250666268301e-09
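To see why a positive variable needs no explicit constraint handling, the same objective can be minimised by hand over the latent value. The sketch below swaps in plain gradient descent for L-BFGS-B and assumes the exponential transform for positive variables; the step size and iteration count are arbitrary illustrative choices.

```python
import math

target = 5.0


def objective(latent):
    x = math.exp(latent)  # Positive by construction.
    return (x ** 0.5 - target) ** 2


def gradient(latent):
    # d/d(latent) of (exp(latent / 2) - target) ** 2.
    root = math.exp(latent / 2)
    return (root - target) * root


latent = math.log(10.0)  # Corresponds to initialising `x` to 10.
for _ in range(200):
    latent -= 0.05 * gradient(latent)

x = math.exp(latent)
print(x, objective(latent))  # `x` approaches target ** 2 = 25.
```

The update acts on an unconstrained number, so no step can ever make `x` negative.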
import tensorflow as tf
from varz.tensorflow import Vars, minimise_l_bfgs_b
target = 5.0
def objective(vs):
# Get a variable named "x", which must be positive, initialised to 10.
x = vs.pos(10.0, name="x")
return (x ** 0.5 - target) ** 2
>>> vs = Vars(tf.float64)
>>> minimise_l_bfgs_b(objective, vs)
3.17785950743424e-19 # Final objective function value.
>>> vs['x'] - target ** 2
<tf.Tensor: id=562, shape=(), dtype=float64, numpy=-5.637250666268301e-09>
>>> vs = Vars(tf.float64)
>>> minimise_l_bfgs_b(objective, vs, jit=True) # Speed up optimisation with TF's JIT!
3.17785950743424e-19
import torch
from varz.torch import Vars, minimise_l_bfgs_b
target = torch.tensor(5.0, dtype=torch.float64)
def objective(vs):
# Get a variable named "x", which must be positive, initialised to 10.
x = vs.pos(10.0, name="x")
return (x ** 0.5 - target) ** 2
>>> vs = Vars(torch.float64)
>>> minimise_l_bfgs_b(objective, vs)
array(3.17785951e-19) # Final objective function value.
>>> vs["x"] - target ** 2
tensor(-5.6373e-09, dtype=torch.float64)
>>> vs = Vars(torch.float64)
>>> minimise_l_bfgs_b(objective, vs, jit=True) # Speed up optimisation with PyTorch's JIT!
array(3.17785951e-19)
import jax.numpy as jnp
from varz.jax import Vars, minimise_l_bfgs_b
target = 5.0
def objective(vs):
# Get a variable named "x", which must be positive, initialised to 10.
x = vs.pos(10.0, name="x")
return (x ** 0.5 - target) ** 2
>>> vs = Vars(jnp.float64)
>>> minimise_l_bfgs_b(objective, vs)
array(3.17785951e-19) # Final objective function value.
>>> vs["x"] - target ** 2
-5.637250666268301e-09
>>> vs = Vars(jnp.float64)
>>> minimise_l_bfgs_b(objective, vs, jit=True) # Speed up optimisation with Jax's JIT!
array(3.17785951e-19)
import jax.numpy as jnp
from varz.jax import Vars, minimise_l_bfgs_b
target = 5.0
def objective(vs, prev_x):
# Get a variable named "x", which must be positive, initialised to 10.
x = vs.pos(10.0, name="x")
# In addition to the objective function value, also return `x` so that
# we can log it.
return (x ** 0.5 - target) ** 2, x
objs = []
xs = []
def callback(obj, x):
objs.append(obj)
xs.append(x)
# Return a dictionary of extra information to show in the progress display.
return {"x": x}
>>> vs = Vars(jnp.float64)
>>> minimise_l_bfgs_b(objective, (vs, 0), trace=True, jit=True, callback=callback)
Minimisation of "objective":
Iteration 1/1000:
Time elapsed: 0.0 s
Time left: 19.0 s
Objective value: 0.04567
x: 27.18
Iteration 6/1000:
Time elapsed: 0.1 s
Time left: 7.4 s
Objective value: 4.520e-04
x: 24.99
Done!
Termination message:
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
(array(3.17785951e-19), DeviceArray(24.99999999, dtype=float64))
>>> vs["x"] - target ** 2
DeviceArray(-5.63725067e-09, dtype=float64)
>>> objs
[array(3.3772234),
array(0.04567386),
array(0.03582296),
array(0.00014534),
array(5.18203996e-07),
array(6.81622668e-12),
array(3.17785951e-19)]
>>> xs
[DeviceArray(10., dtype=float64),
DeviceArray(27.18281828, dtype=float64),
DeviceArray(23.14312757, dtype=float64),
DeviceArray(24.87958747, dtype=float64),
DeviceArray(25.00719916, dtype=float64),
DeviceArray(24.99997389, dtype=float64),
DeviceArray(24.99999999, dtype=float64)]