Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Aug 17, 2023
1 parent f73a512 commit dcfd6ee
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/prog_algs/state_estimators/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def __init__(self, model, x0, **kwargs):
# Added to avoid float/int issues
self.parameters['num_particles'] = int(self.parameters['num_particles'])
sample_gen = x0.sample(self.parameters['num_particles'])
samples = [array(sample_gen.key(k), dtype=float64) for k in x0.keys()]
samples = {k: array(sample_gen.key(k), dtype=float64) for k in x0.keys()}

self.particles = model.StateContainer(array(samples, dtype=float64))
self.particles = model.StateContainer(samples)

if 'R' in self.parameters:
# For backwards compatibility
Expand Down
17 changes: 16 additions & 1 deletion tests/test_state_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random

from prog_models import PrognosticsModel, LinearModel
from prog_models.models import ThrownObject, BatteryElectroChem, PneumaticValveBase
from prog_models.models import ThrownObject, BatteryElectroChem, PneumaticValveBase, BatteryElectroChemEOD
from prog_algs.state_estimators import ParticleFilter, KalmanFilter, UnscentedKalmanFilter
from prog_algs.uncertain_data import ScalarData, MultivariateNormalDist, UnweightedSamples

Expand Down Expand Up @@ -557,6 +557,21 @@ def future_loading(t, x=None):
times = simulation_result.times
for t, u, z in zip(times, inputs.data, outputs.data):
kf.estimate(t, u, z)

def test_PF_particle_ordering(self):
"""
This is testing for a bug found by @mstraut where particle filter was mixing up the keys if users:
1. Do not call m.initialize(), and instead
2. provide a state as a dictionary instead of a state container, and
3. order the states in a different order than m.states
"""
m = BatteryElectroChemEOD()
x0 = m.parameters['x0'] # state as a dictionary with the wrong order
filt = ParticleFilter(m, x0, num_particles=2)
for key in m.states:
self.assertEqual(filt.particles[key][0], x0[key])
self.assertEqual(filt.particles[key][1], x0[key])


# This allows the module to be executed directly
def run_tests():
Expand Down

0 comments on commit dcfd6ee

Please sign in to comment.