Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gaussian Sum Filter to main #122

Closed
wants to merge 9 commits into from
Prev Previous commit
Next Next commit
gsf working fine
  • Loading branch information
syedshabbirahmed committed May 16, 2024
commit 8e3018e3f0d2bccd01d26004fa2a5458dfc65d14
52 changes: 21 additions & 31 deletions examples/ex_gsf_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from typing import List


"""
This example runs an Interacting Multiple Model filter to estimate the process model noise matrix
for a state that is on a Lie group. The performance is compared to an EKF that knows the ground
Expand Down Expand Up @@ -53,51 +54,40 @@ def gsf_trial(trial_number: int) -> List[nav.GaussianResult]:
"""
np.random.seed(trial_number)
state_true, input_list, meas_list = dg.generate(x0, 0, t_max, True)

x0_check = x0.plus(nav.randvec(P0))

# Initial state estimates
x = [SE2State([0, -5, 0], stamp=0.0),
SE2State([0, 5, 0], stamp=0.0)]
x = [x_.plus(nav.randvec(P0)) for x_ in x]

weights = [1, 1]
x0_check = nav.gsf.GMMState(
[nav.StateWithCovariance(_x, P0) for _x in x], weights
)

estimate_list = nav.gsf.run_gsf_filter(
gsf, x0_check, P0, input_list, meas_list
)

results = [
nav.imm.IMMResult(estimate_list[i], state_true[i])
nav.gsf.GSFResult(estimate_list[i], state_true[i])
for i in range(len(estimate_list))
]

return nav.imm.IMMResultList(results)
return nav.gsf.GSFResultList(results)

N = 1
results = nav.monte_carlo(gsf_trial, N)
N = 2
results = gsf_trial(0)

if __name__ == "__main__":
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="whitegrid")

fig, ax = plt.subplots(1, 1)
ax.plot(results.stamp, results.average_nees, label="IMM NEES")
ax.plot(
results.stamp, results.expected_nees, color="r", label="Expected NEES"
)
ax.plot(
results.stamp,
results.nees_lower_bound(0.99),
color="k",
linestyle="--",
label="99 percent c.i.",
)
ax.plot(
results.stamp,
results.nees_upper_bound(0.99),
color="k",
linestyle="--",
)
ax.set_title("{0}-trial average NEES".format(results.num_trials))
ax.set_ylim(0, None)
ax.set_xlabel("Time (s)")
ax.set_ylabel("NEES")
ax.legend()

fig, ax = nav.plot_error(results)
ax[0].set_title("Error plots")
ax[0].set_ylabel("Error (rad)")
ax[1].set_ylabel("Error (m)")
ax[2].set_ylabel("Error (m)")
ax[2].set_xlabel("Time (s)")
plt.show()
54 changes: 36 additions & 18 deletions navlie/gsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tqdm import tqdm
from navlie.filters import ExtendedKalmanFilter
from scipy.stats import multivariate_normal
from navlie.utils import GaussianResultList, GaussianResult, state_interp
from navlie.utils import GaussianResult, GaussianResultList
from navlie.imm import gaussian_mixing

class GMMState:
Expand All @@ -34,7 +34,7 @@ def stamp(self):

def copy(self) -> "GMMState":
x_copy = [x.copy() for x in self.states]
return GMMState(x_copy, self.weights.copy())
return GMMState(x_copy, self.weights.copy() / np.sum(self.weights))

class GSFResult(GaussianResult):
__slots__ = [
Expand All @@ -57,9 +57,31 @@ def __init__(self, gsf_estimate: GMMState, state_true: State):
),
state_true,
)

self.weights = gsf_estimate.weights

class GSFResultList(GaussianResultList):
__slots__ = [
"stamp",
"state",
"state_true",
"covariance",
"error",
"ees",
"nees",
"md",
"three_sigma",
"value",
"value_true",
"dof",
"weights",
]

def __init__(self, result_list: List[GSFResult]):
super().__init__(result_list)
self.weights = np.array(
[r.weights for r in result_list]
)


class GaussianSumFilter:
"""
Expand Down Expand Up @@ -126,8 +148,8 @@ def predict(
def correct(
self,
x: GMMState,
u: Input,
y: Measurement,
u: Input,
) -> GMMState:
"""
Corrects the state estimate using a measurement. The user must provide
Expand All @@ -145,27 +167,29 @@ def correct(
GMMState
Corrected states with associated weights.
"""
x_check = x.copy()
n_modes = len(x.states)
weights_check = x.weights.copy()

x_hat = []
weights_hat = []
weights_hat = np.zeros(n_modes)
for i in range(n_modes):
x, details_dict = self.ekf.correct(x.states[i], y, u)
x, details_dict = self.ekf.correct(x_check.states[i], y, u,
output_details=True)
x_hat.append(x)
z = details_dict["z"]
S = details_dict["S"]
model_likelihood = multivariate_normal.pdf(
z.ravel(), mean=np.zeros(z.shape), cov=S
)
weights_hat.append(weights_check[i] * model_likelihood)
weights_hat[i] = weights_check[i] * model_likelihood

# If all model likelihoods are zero to machine tolerance, np.sum(mu_k)=0 and it fails
# Add this fudge factor to get through those cases.
if np.allclose(weights_hat, np.zeros(weights_hat.shape)):
weights_hat = 1e-10 * np.ones(weights_hat.shape)

weights_hat = np.array(weights_hat) / np.sum(weights_hat)
weights_hat = weights_hat / np.sum(weights_hat)

return GMMState(x_hat, weights_hat)

Expand Down Expand Up @@ -195,8 +219,8 @@ def run_gsf_filter(
meas_data : List[Measurement]
_description_
"""
x = StateWithCovariance(x0, P0)
if x.state.stamp is None:
x = x0.copy()
if x.stamp is None:
raise ValueError("x0 must have a valid timestamp.")

# Sort the data by time
Expand All @@ -205,25 +229,19 @@ def run_gsf_filter(

# Remove all that are before the current time
for idx, u in enumerate(input_data):
if u.stamp >= x.state.stamp:
if u.stamp >= x.stamp:
input_data = input_data[idx:]
break

for idx, y in enumerate(meas_data):
if y.stamp >= x.state.stamp:
if y.stamp >= x.stamp:
meas_data = meas_data[idx:]
break

meas_idx = 0
if len(meas_data) > 0:
y = meas_data[meas_idx]

n_modes = 1
weights = [1.0]
x = GMMState(
[StateWithCovariance(x0, P0)] * n_modes, weights
)

results_list = []
for k in tqdm(range(len(input_data) - 1), disable=disable_progress_bar):
u = input_data[k]
Expand Down
Empty file added navlie/mixture_utils.py
Empty file.