Skip to content

Commit

Permalink
Markovian GPs
Browse files Browse the repository at this point in the history
  • Loading branch information
DanWaxman committed Feb 28, 2024
1 parent 893c9b2 commit f7016d7
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

experiments/data/*
96 changes: 95 additions & 1 deletion src/Lintel/gp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(

@property
def lengthscale(self):
return 1e-6 + jnp.minimum(softplus(self.transformed_lengthscale.value), 1e2)
return 1e-6 + jnp.minimum(softplus(self.transformed_lengthscale.value), 1e6)

@property
def sigma_f(self):
Expand Down Expand Up @@ -223,3 +223,97 @@ def train_op():
print(iter_idx, f_value)

self.update_training_set(self.X, self.y)


class MarkovianGP(objax.Module):
    """Markovian (state-space) Gaussian process for online filtering.

    Maintains a 3-dimensional latent filter state — mean ``m`` (3x1) and
    covariance ``P`` (3x3) — that is propagated in closed form between
    observation times (``get_A`` / ``get_Q``) and corrected with Kalman
    filter updates (``update``).  Only the first state component is
    observed (see ``H``), with a constant mean offset ``C`` added to
    predictions.

    NOTE(review): the forms of ``Pinf`` and ``get_A`` (sqrt(5)/lengthscale
    rate, kappa = 5/3 * sigma_f^2 / lengthscale^2) look like the SDE
    representation of a Matérn-5/2 kernel — confirm against the intended
    kernel.

    Hyperparameters are stored as softplus-transformed ``objax.TrainVar``s
    so that unconstrained gradient updates always map to positive values.
    """

    def __init__(self, lengthscale, sigma_f, sigma_n, C):
        """Create a GP with the given (positive) hyperparameters.

        lengthscale, sigma_f, sigma_n are stored in unconstrained
        (softplus-inverse) space; C is the constant prior mean.
        """
        self.transformed_lengthscale = objax.TrainVar(
            softplus_inv(jnp.asarray(lengthscale))
        )
        self.transformed_sigma_f = objax.TrainVar(softplus_inv(jnp.array(sigma_f)))
        self.transformed_sigma_n = objax.TrainVar(softplus_inv(jnp.asarray(sigma_n)))
        # Constant mean added to every prediction.
        self.C = C
        # Sentinel meaning "no observation yet": the first predict() then
        # computes dt = max(t_star - inf, 0) = 0.
        self.t_last = jnp.inf

        # Filter state initialized at the stationary distribution.
        self.m = jnp.zeros((3, 1))
        self.P = self.Pinf

    @property
    def lengthscale(self):
        # Positive lengthscale; softplus keeps it > 0, the 1e-6 floor and
        # 1e6 cap guard against numerical extremes during optimization.
        return 1e-6 + jnp.minimum(softplus(self.transformed_lengthscale.value), 1e6)

    @property
    def sigma_f(self):
        # Positive signal scale, capped at ~1e2.
        return 1e-6 + jnp.minimum(softplus(self.transformed_sigma_f.value), 1e2)

    @property
    def sigma_n(self):
        # Positive observation-noise scale, capped at ~1e2.
        return 1e-6 + jnp.minimum(softplus(self.transformed_sigma_n.value), 1e2)

    @property
    def H(self):
        # Observation matrix (1x3): only the first state component is read.
        return jnp.array([[1, 0, 0]])

    @property
    def Pinf(self):
        """Stationary state covariance (3x3) for the current hyperparameters."""
        kappa = 5.0 / 3.0 * self.sigma_f**2 / self.lengthscale**2.0

        return jnp.array(
            [
                [self.sigma_f**2, 0.0, -kappa],
                [0.0, kappa, 0.0],
                [-kappa, 0.0, 25.0 * self.sigma_f**2 / self.lengthscale**4.0],
            ]
        )

    def get_Q(self, A):
        """Process-noise covariance for one step with transition ``A``.

        Q = Pinf - A Pinf A^T, i.e. chosen so that the stationary
        covariance is exactly preserved by the discrete-time transition.
        """
        return self.Pinf - A @ self.Pinf @ A.T

    def get_A(self, dt):
        """Closed-form discrete-time state transition matrix (3x3) for a
        time gap of ``dt``.

        Note dt = 0 yields exp(0) * (0 * [...] + I) = I, the identity.
        """
        lam = jnp.sqrt(5.0) / self.lengthscale
        dtlam = dt * lam
        A = jnp.exp(-dtlam) * (
            dt
            * jnp.array(
                [
                    [lam * (0.5 * dtlam + 1.0), dtlam + 1.0, 0.5 * dt],
                    [-0.5 * dtlam * lam**2, lam * (1.0 - dtlam), 1.0 - 0.5 * dtlam],
                    [
                        lam**3 * (0.5 * dtlam - 1.0),
                        lam**2 * (dtlam - 3),
                        lam * (0.5 * dtlam - 2.0),
                    ],
                ]
            )
            + jnp.eye(3)
        )
        return A

    def predict(self, t_star):
        """Propagate the state to time ``t_star`` and return the predictive
        distribution of the observation.

        Returns (m, sigma2, m_evolved, P_evolved): predictive mean (1x1,
        includes the offset C), predictive variance (1x1, includes
        observation noise), and the evolved latent mean/covariance —
        the latter two are passed straight back into ``update``.
        Does NOT mutate the filter state.
        """
        # Clamp so the jnp.inf "no data yet" sentinel gives dt = 0.
        dt = max(t_star - self.t_last, 0)
        A_n = self.get_A(dt)
        m_evolved = A_n @ self.m
        P_evolved = A_n @ self.P @ A_n.T + self.get_Q(A_n)

        m = self.H @ m_evolved + self.C
        sigma2 = self.H @ P_evolved @ self.H.T + self.sigma_n**2

        return m, sigma2, m_evolved, P_evolved

    def update(self, t_star, y_star, m, sigma2, m_evolved, P_evolved):
        """Kalman-filter correction with observation ``y_star`` at time
        ``t_star``.

        The last four arguments are exactly the tuple returned by
        ``predict(t_star)``.  Mutates self.m, self.P, and self.t_last.
        """
        # Kalman gain k (3x1): solve sigma2 * k^T = H P for k, using the
        # symmetric-positive-definite solver path.
        k = jax.scipy.linalg.solve(sigma2, self.H @ P_evolved, assume_a="pos").T
        self.m = m_evolved + k @ (y_star - m)
        self.P = P_evolved - k @ self.H @ P_evolved
        self.t_last = t_star

    def reset_and_filter(self, t, y, mean):
        """Reset the filter state, set the constant mean to ``mean``, and
        run predict/update over the full sequence (t, y) in order."""
        self.C = mean
        self.t_last = jnp.inf

        self.m = jnp.zeros((3, 1))
        self.P = self.Pinf
        N = t.shape[0]

        for n in range(N):
            o = self.predict(t[n])
            self.update(t[n], y[n], *o)
4 changes: 2 additions & 2 deletions src/Lintel/intel.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def predict_and_update(
)

# Calculate the w-hat ("what") weights from Eq. (15)
whats = self.weights**self.alpha + 1e-6
whats = self.weights**self.alpha + 1e-4
whats = whats / np.sum(whats)

# Product of experts predictive distributions, Eqs. (21-22)
Expand All @@ -103,7 +103,7 @@ def predict_and_update(
self.yprime = []

# Update means if time since mean update is more than L
if self.t_since_mean_update > self.L:
if self.t_since_mean_update >= self.L:
for m in range(self.M):
self.gps[m].C = np.mean(self.y[-self.L :])
self.t_since_mean_update = 0
Expand Down

0 comments on commit f7016d7

Please sign in to comment.