Merge pull request #239 from nasa/bug/horizon_test
Bug/horizon test - fixes bug in horizon hotfix and adds tests of horizon feature
kjjarvis committed Aug 17, 2023
2 parents 38fb4d3 + 200135d commit f73a512
Showing 4 changed files with 89 additions and 8 deletions.
8 changes: 4 additions & 4 deletions examples/horizon.py
@@ -20,7 +20,7 @@ def run_example():
# Step 1: Setup model & future loading
def future_loading(t, x = None):
return {}
m = ThrownObject(process_noise = 0.25, measurement_noise = 0.2)
m = ThrownObject(process_noise = 0.2, measurement_noise = 0.1)
initial_state = m.initialize()

# Step 2: Demonstrating state estimator
@@ -50,17 +50,17 @@ def future_loading(t, x = None):
# THIS IS WHERE WE DIVERGE FROM THE THROWN_OBJECT_EXAMPLE
# Here we set a prediction horizon
# We're saying we are not interested in any events that occur after this time
PREDICTION_HORIZON = 7.75
PREDICTION_HORIZON = 7.67
samples = filt.x # Since we're using a particle filter, which is also sample-based, we can directly use the samples, without changes
STEP_SIZE = 0.01
STEP_SIZE = 0.001
mc_results = mc.predict(samples, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON)
print("\nPredicted Time of Event:")
metrics = mc_results.time_of_event.metrics()
pprint(metrics) # Note this takes some time
mc_results.time_of_event.plot_hist(keys = 'impact')
mc_results.time_of_event.plot_hist(keys = 'falling')

print("\nSamples where impact occurs before horizon: {:.2f}%".format(metrics['impact']['number of samples']/NUM_SAMPLES*100))
print("\nSamples where impact occurs before horizon: {:.2f}%".format(metrics['impact']['number of samples']/mc.parameters['n_samples']*100))

# Step 4: Show all plots
import matplotlib.pyplot as plt # For plotting
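For context, the workflow the updated example follows (and which the new test below exercises) can be sketched end to end. This is an illustration only: the parameter values are assumptions, and the actual example also runs a ParticleFilter and passes its samples to predict rather than the initial state.

    from prog_models.models.thrown_object import ThrownObject
    from prog_algs import predictors

    m = ThrownObject()

    def future_loading(t, x=None):
        return {}   # ThrownObject takes no external load

    mc = predictors.MonteCarlo(m)
    mc.parameters['n_samples'] = 100   # the sample count lives in the predictor's parameters

    # Events occurring after the horizon are ignored; samples that never reach an
    # event within the horizon report None for that event's time of event.
    mc_results = mc.predict(m.initialize(), future_loading, dt=0.01, horizon=7.75)

    metrics = mc_results.time_of_event.metrics()
    n_impact = metrics['impact']['number of samples']   # samples where 'impact' occurred before the horizon
    print("Impact before horizon: {:.2f}%".format(n_impact / mc.parameters['n_samples'] * 100))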
8 changes: 4 additions & 4 deletions src/prog_algs/predictors/monte_carlo.py
@@ -95,10 +95,6 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
**params
)
else:
# Since horizon is relative to t0 (the simulation starting point),
# we must subtract the difference in current t0 from the initial (i.e., prediction t0)
# each subsequent simulation
params['horizon'] = HORIZON - (params['t0'] - t0)

# Simulate
events_remaining = params['events'].copy()
@@ -111,6 +107,10 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)

# Non-vectorized prediction
while len(events_remaining) > 0: # Still events to predict
# Since horizon is relative to t0 (the simulation starting point),
# we must subtract the difference in current t0 from the initial (i.e., prediction t0)
# each subsequent simulation
params['horizon'] = HORIZON - (params['t0'] - t0)
(t, u, xi, z, es) = simulate_to_threshold(future_loading_eqn,
first_output,
threshold_keys = events_remaining,
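The fix above moves the horizon adjustment from the pre-loop setup into the per-event loop: params['t0'] advances after each call to simulate_to_threshold, so the remaining horizon has to be recomputed before every simulation rather than once. A simplified sketch of that bookkeeping, using the names from the diff (not the full library source):

    # Sketch of the corrected horizon bookkeeping (illustration only).
    t0 = params['t0']             # time at which the prediction started
    HORIZON = params['horizon']   # horizon, specified relative to t0

    while len(events_remaining) > 0:
        # params['t0'] has advanced to the time of the last event reached, so the
        # remaining window must shrink by the time already simulated.
        params['horizon'] = HORIZON - (params['t0'] - t0)
        # ... simulate_to_threshold(...) runs here with the updated horizon ...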
6 changes: 6 additions & 0 deletions tests/__main__.py
@@ -7,6 +7,7 @@
from .test_metrics import run_tests as metrics_main
from .test_visualize import run_tests as visualize_main
from .test_tutorials import run_tests as tutorials_main
from .test_horizon import run_tests as horizon_main

import unittest
import sys
@@ -69,5 +70,10 @@ def run_basic_ex():
except Exception:
was_successful = False

try:
horizon_main()
except Exception:
was_successful = False

if not was_successful:
raise Exception('Tests Failed')

75 changes: 75 additions & 0 deletions tests/test_horizon.py
@@ -0,0 +1,75 @@
from io import StringIO
import sys
import unittest

from prog_models.models.thrown_object import ThrownObject
from prog_algs import *

class TestHorizon(unittest.TestCase):
def setUp(self):
# set stdout (so it won't print)
sys.stdout = StringIO()

def tearDown(self):
sys.stdout = sys.__stdout__

def test_horizon_ex(self):
# Setup model
m = ThrownObject(process_noise = 0.25, measurement_noise = 0.2)
# Change parameters (to make simulation faster)
m.parameters['thrower_height'] = 1.0
m.parameters['throwing_speed'] = 10.0
initial_state = m.initialize()

# Define future loading (necessary for prediction call)
def future_loading(t, x = None):
return {}

# Setup Predictor (smaller sample size for efficiency)
mc = predictors.MonteCarlo(m)
mc.parameters['n_samples'] = 50

# Perform a prediction
# With this horizon, all samples will reach 'falling', but only some will reach 'impact'
PREDICTION_HORIZON = 2.127
STEP_SIZE = 0.001
mc_results = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON)

# 'falling' happens before the horizon is met
falling_res = [mc_results.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['falling'] is not None]
self.assertEqual(len(falling_res), mc.parameters['n_samples'])

# 'impact' happens around the horizon, so some samples have reached this event and others haven't
impact_res = [mc_results.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['impact'] is not None]
self.assertLess(len(impact_res), mc.parameters['n_samples'])

# Try again with very low prediction_horizon, where no events are reached
# Note: here we count how many None values there are for each event (in the above and below examples, we count values that are NOT None)
mc_results_no_event = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon = 0.3)
falling_res_no_event = [mc_results_no_event.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results_no_event.time_of_event[iter]['falling'] is None]
impact_res_no_event = [mc_results_no_event.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results_no_event.time_of_event[iter]['impact'] is None]
self.assertEqual(len(falling_res_no_event), mc.parameters['n_samples'])
self.assertEqual(len(impact_res_no_event), mc.parameters['n_samples'])

# Finally, try without horizon, all events should be reached for all samples
mc_results_all_event = mc.predict(initial_state, future_loading, dt=STEP_SIZE)
falling_res_all_event = [mc_results_all_event.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results_all_event.time_of_event[iter]['falling'] is not None]
impact_res_all_event = [mc_results_all_event.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results_all_event.time_of_event[iter]['impact'] is not None]
self.assertEqual(len(falling_res_all_event), mc.parameters['n_samples'])
self.assertEqual(len(impact_res_all_event), mc.parameters['n_samples'])

# This allows the module to be executed directly
def run_tests():
unittest.main()

def main():
load_test = unittest.TestLoader()
runner = unittest.TextTestRunner()
print("\n\nTesting Horizon functionality")
result = runner.run(load_test.loadTestsFromTestCase(TestHorizon)).wasSuccessful()

if not result:
raise Exception("Failed test")

if __name__ == '__main__':
main()
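Assuming the package layout shown in this diff (an assumption; the exact invocation is not stated here), the new tests can be run either on their own or through the aggregate runner:

    python -m tests.test_horizon    # executes main(), which loads TestHorizon directly
    python -m tests                 # runs the full suite, now including horizon_main()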
