This is my "lazy" wrapper around jax, intended to minimize extra work in setting up optimization for simple custom models. For more advanced deep-nn models, I'd recommend libraries like Haiku, Elegy, Flax, or Trax.
import laxy
import jax.numpy as jnp
def model(params, inputs):
out = inputs["x"] * params["m"] + params["b"]
loss = jnp.square(inputs["y"] - out)
return out, loss
opt = laxy.OPT(model, params={"a":1.0,"b":0.0})
opt.fit(inputs={"x":x,"y":y})
Examples:
-
How do I save/load weights?
# save weights = opt.get_params() jnp.save("weights.npy",weights) # load weights = jnp.load("weights.npy",allow_pickle=True) opt.set_params(weights)
-
Can I use neural networks in my model?
from jax.experimental import stax stax_layers = stax.serial(stax.Dense(5), stax.Elu, stax.Dense(1), stax.Dropout(0.5)) nn_params, nn_layers = laxy.STAX(stax_layers, input_shape=(None,10)) def model(params, inputs): out = nn_layers(params["nn"], inputs["x"], rng=inputs["key"]) + params["a"] loss = jnp.square(out - inputs["y"]).sum() return out, loss opt = laxy.OPT(model, params={"nn":nn_params,"a":1.0})
-
Can I use random variables?
A random key is automatically added to the
inputs
dict at each optimization step. The seed for this key is set atlaxy.OPT(model, seed=0)
def model(params, inputs): out = inputs["x"] * params["m"] + jax.random.normal(inputs["key"],(1,)) loss = jnp.square(inputs["y"] - out) return out, loss
More than one key?
def model(params, inputs): keys = jax.random.split(inputs["key"],2) out = inputs["x"] * params["m"] + jax.random.normal(keys[0],(1,)) loss = jnp.square(inputs["y"] - out) + jax.random.uniform(keys[1],(1,)) return out, loss
-
Can I freeze a subset of weights?
Freeze forever:
def model(params, inputs): out = inputs["x"] * params["m"] + laxy.freeze(params["b"]) loss = jnp.square(inputs["y"] - out) return out, loss
Conditional freeze:
def model(params, inputs): out = inputs["x"] * params["m"] + laxy.freeze_cond(inputs["freeze"],params["b"]) loss = jnp.square(inputs["y"] - out) return out, loss opt = laxy.OPT(model, params={"a":1.0,"b":0.0}) opt.fit(inputs={"x":x,"y":y,"freeze":True}) # freeze opt.fit(inputs={"x":x,"y":y,"freeze":False}) # unfreeze