Differentiable QP solver in JAX.
This package can be used for solving convex quadratic programs of the following form:
where jit
and vmap
functionality, as well as differentiated with reverse-mode grad
.
The QP is solved with a primal-dual interior point algorithm detailed in cvxgen, with the solution to the linear systems computed with reduction techniques from cvxopt. At an approximate primal-dual solution, the the primal variable
To install directly from github using pip
:
$ pip install qpax
Alternatively, to install from source in editable mode:
$ pip install -e .
The solver tolerance (solver_tol
) should be something reasonable given the available precision. With 32bit precision (the default in JAX), solver_tol
should be greater than 1e-5
.
Precision | Tolerance |
---|---|
jnp.float32 |
solver_tol [1e-5, 1e-2]
|
jnp.float64 |
solver_tol [1e-12, 1e-2]
|
In order to enable 64bit precision, you can do the following at startup:
# again, this only works on startup!
import jax
jax.config.update("jax_enable_x64", True)
This is taken from the JAX - The Sharp Bits.
We can solve QPs with qpax in a way that plays nice with JAX's jit
and vmap
:
import qpax
# solve QP (this can be combined with jit or vmap)
x, s, z, y, converged, iters = qpax.solve_qp(Q, q, A, b, G, h, solver_tol=1e-6)
Here let's solve a batch of nonnegative least squares problems as QPs. This outlines two bits of functionality from qpax
, first is the ability to solve QPs without any equality constraints, and second is the ability to vmap
over a batch of QPs.
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import qpax
import timeit
"""
solve batched non-negative least squares (nnls) problems
min_x |Fx - g|^2
st x >= 0
"""
n = 5 # size of x
m = 10 # rows in F
# create data for N_qps random nnls problems
N_qps = 10000
Fs = jnp.array(np.random.randn(N_qps, m, n))
gs = jnp.array(np.random.randn(N_qps, m))
@jit
def form_qp(F, g):
# convert the least squares to qp form
n = F.shape[1]
Q = F.T @ F
q = -F.T @ g
G = -jnp.eye(n)
h = jnp.zeros(n)
A = jnp.zeros((0, n))
b = jnp.zeros(0)
return Q, q, A, b, G, h
# create the QPs in a batched fashion
Qs, qs, As, bs, Gs, hs = vmap(form_qp, in_axes = (0, 0))(Fs, gs)
# create function for solving a batch of QPs
batch_qp = jit(vmap(qpax.solve_qp_primal, in_axes = (0, 0, 0, 0, 0, 0)))
xs = batch_qp(Qs, qs, As, bs, Gs, hs)
Alternatively, if we are only looking to use the primal variable x
, we can use solve_qp_primal
which enables automatic differentiation:
import jax
import jax.numpy as jnp
import qpax
def loss(Q, q, A, b, G, h):
x = qpax.solve_qp_primal(Q, q, A, b, G, h, solver_tol=1e-4, target_kappa=1e-3)
x_bar = jnp.ones(len(q))
return jnp.dot(x - x_bar, x - x_bar)
# gradient of loss function
loss_grad = jax.grad(loss, argnums = (0, 1, 2, 3, 4, 5))
# compatible with jit
loss_grad_jit = jax.jit(loss_grad)
# calculate derivatives
derivs = loss_grad_jit(Q, q, A, b, G, h)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs
where target_kappa
is used to determine how much smoothing should be applied to the gradients through solve_qp_primal
. For more detail on target_kappa
, please refer to the paper.
@misc{tracy2024differentiability,
title={On the Differentiability of the Primal-Dual Interior-Point Method},
author={Kevin Tracy and Zachary Manchester},
year={2024},
eprint={2406.11749},
archivePrefix={arXiv},
primaryClass={math.OC}
}