Skip to content

Commit b73079d

Browse files
authored
Merge branch 'dev' into feature/CMAPSS
2 parents 95ee179 + b325796 commit b73079d

File tree

8 files changed

+69
-13
lines changed

8 files changed

+69
-13
lines changed

examples/dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def run_example():
6363
plt.ylabel('Voltage (V)')
6464
plt.show()
6565

66-
6766
# This allows the module to be executed directly
6867
if __name__=='__main__':
6968
run_example()

sphinx_config/datasets.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Datasets
2+
================================================================
3+
4+
The `prog_models` dataset subpackage is used to download labeled prognostics data for use in model building, analysis, or validation. Every dataset comes equipped with a `load_data` function which loads the specified data. Some datasets require a dataset number or id. This indicates the specific data to load from the larger dataset. The format of the data is specific to the dataset downloaded. Details of the specific datasets are summarized below:
5+
6+
.. contents::
7+
:backlinks: top
8+
9+
Variable Load Battery Data (nasa_battery)
10+
----------------------------------------------------
11+
.. autofunction:: prog_models.datasets.nasa_battery.load_data
12+
13+
14+
CMAPSS Jet Engine Data (nasa_cmapss)
15+
----------------------------------------------------
16+
.. autofunction:: prog_models.datasets.nasa_cmapss.load_data
17+

sphinx_config/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ If you are new to this package, see `<getting_started.html>`__.
1717
models
1818
prognostics_model
1919
simresult
20+
datasets
2021
exceptions
2122
ProgAlgs <https://nasa.github.io/prog_algs>
2223
dev_guide

src/prog_models/datasets/nasa_battery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,7 @@ def load_data(batt_id):
101101
]
102102

103103
return desc, result
104+
105+
def clear_cache():
106+
"""Clears the cache of downloaded data"""
107+
cache.clear()

src/prog_models/prognostics_model.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numbers import Number
77
import numpy as np
88
from copy import deepcopy
9+
import itertools
910
from warnings import warn
1011
from collections import abc, namedtuple
1112
from .sim_result import SimResult, LazySimResult
@@ -419,7 +420,7 @@ def __next_state(self, x, u, dt) -> dict:
419420
"""
420421

421422
# Calculate next state and add process noise
422-
next_state = self.apply_process_noise(self.next_state(x, u, dt))
423+
next_state = self.apply_process_noise(self.next_state(x, u, dt), dt)
423424

424425
# Apply Limits
425426
return self.apply_limits(next_state)
@@ -744,9 +745,10 @@ def simulate_to_threshold(self, future_loading_eqn, first_output = None, thresho
744745
raise ProgModelInputException("'dt' must be a number or function, was a {}".format(type(config['dt'])))
745746
if isinstance(config['dt'], Number) and config['dt'] < 0:
746747
raise ProgModelInputException("'dt' must be positive, was {}".format(config['dt']))
747-
if not isinstance(config['save_freq'], Number):
748+
if not isinstance(config['save_freq'], Number) and not isinstance(config['save_freq'], tuple):
748749
raise ProgModelInputException("'save_freq' must be a number, was a {}".format(type(config['save_freq'])))
749-
if config['save_freq'] <= 0:
750+
if (isinstance(config['save_freq'], Number) and config['save_freq'] <= 0) or \
751+
(isinstance(config['save_freq'], tuple) and config['save_freq'][1] <= 0):
750752
raise ProgModelInputException("'save_freq' must be positive, was {}".format(config['save_freq']))
751753
if not isinstance(config['save_pts'], abc.Iterable):
752754
raise ProgModelInputException("'save_pts' must be list or array, was a {}".format(type(config['save_pts'])))
@@ -795,9 +797,20 @@ def check_thresholds(thresholds_met):
795797
saved_states = []
796798
saved_outputs = []
797799
saved_event_states = []
798-
save_freq = config['save_freq']
799800
horizon = t+config['horizon']
800-
next_save = t+save_freq
801+
if isinstance(config['save_freq'], tuple):
802+
# Tuple used to specify start and frequency
803+
t_step = config['save_freq'][1]
804+
# Use starting time or the next multiple
805+
t_start = config['save_freq'][0]
806+
start = max(t_start, t - (t-t_start)%t_step)
807+
iterator = itertools.count(start, t_step)
808+
else:
809+
# Otherwise - start is t0
810+
t_step = config['save_freq']
811+
iterator = itertools.count(t, t_step)
812+
next(iterator) # Skip current time
813+
next_save = next(iterator)
801814
save_pt_index = 0
802815
save_pts = config['save_pts']
803816
save_pts.append(1e99) # Add last endpoint
@@ -839,13 +852,15 @@ def next_time(t, x):
839852

840853
while t < horizon:
841854
dt = next_time(t, x)
842-
t = t + dt
855+
t = t + dt/2
856+
# Use state at midpoint of step to best represent the load during the duration of the step
843857
u = future_loading_eqn(t, x)
858+
t = t + dt/2
844859
x = next_state(x, u, dt)
845860

846861
# Save if at appropriate time
847862
if (t >= next_save):
848-
next_save += save_freq
863+
next_save = next(iterator)
849864
update_all()
850865
if (t >= save_pts[save_pt_index]):
851866
save_pt_index += 1

src/prog_models/utils/containers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ class DictLikeMatrixWrapper():
1818
The contained data (e.g., input, state, output). If numpy array should be column vector in same order as keys
1919
"""
2020
def __init__(self, keys, data):
21-
self._keys = keys
21+
self._keys = keys.copy()
2222
if isinstance(data, np.matrix):
2323
self.matrix = np.array(data)
2424
elif isinstance(data, np.ndarray):
2525
self.matrix = data
26-
elif isinstance(data, dict):
26+
elif isinstance(data, dict) or isinstance(data, DictLikeMatrixWrapper):
2727
self.matrix = np.array([[data[key]] for key in keys])
2828
else:
2929
raise ProgModelTypeError(f"Input must be a dictionary or numpy array, not {type(data)}")
@@ -34,6 +34,13 @@ def __getitem__(self, key):
3434
def __setitem__(self, key, value):
3535
self.matrix[self._keys.index(key)] = np.atleast_1d(value)
3636

37+
def __delitem__(self, key):
38+
self.matrix = np.delete(self.matrix, self._keys.index(key), axis=0)
39+
self._keys.remove(key)
40+
41+
def __add__(self, other):
42+
return DictLikeMatrixWrapper(self._keys, self.matrix + other.matrix)
43+
3744
def __iter__(self):
3845
return iter(self._keys)
3946

tests/test_centrifugal_pump.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest
44
from prog_models.models.centrifugal_pump import CentrifugalPump, CentrifugalPumpWithWear, CentrifugalPumpBase
55

6+
67
class TestCentrifugalPump(unittest.TestCase):
78
def test_centrifugal_pump_base(self):
89
pump = CentrifugalPumpBase(process_noise= 0)
@@ -73,7 +74,7 @@ def future_loading(t, x=None):
7374
pump.parameters['wA'] = 1e-2
7475
pump.parameters['wThrust'] = 1e-10
7576
(times, inputs, states, outputs, event_states) = pump.simulate_to_threshold(future_loading, pump.output(pump.initialize(future_loading(0),{})))
76-
self.assertAlmostEqual(times[-1], 23891)
77+
self.assertAlmostEqual(times[-1], 23892)
7778

7879
def test_centrifugal_pump_with_wear(self):
7980
pump = CentrifugalPumpWithWear(process_noise= 0)
@@ -148,7 +149,7 @@ def future_loading(t, x=None):
148149
pump.parameters['x0']['wA'] = 1e-2
149150
pump.parameters['x0']['wThrust'] = 1e-10
150151
(times, inputs, states, outputs, event_states) = pump.simulate_to_threshold(future_loading, pump.output(pump.initialize(future_loading(0),{})))
151-
self.assertAlmostEqual(times[-1], 23891)
152+
self.assertAlmostEqual(times[-1], 23892)
152153

153154
# Check warning when changing overwritten Parameters
154155
with self.assertWarns(UserWarning):

tests/test_dict_like_matrix_wrapper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ def _checks(self, c1):
5454
self.assertEqual(c1['d'], 7)
5555
self.assertListEqual(c1.keys(), ['a', 'b', 'c', 'd'])
5656

57+
# deleting items
58+
del c1['a']
59+
self.assertTrue((c1.matrix == np.array([[2], [5], [7]])).all())
60+
self.assertListEqual(c1.keys(), ['b', 'c', 'd'])
61+
del c1['c']
62+
self.assertTrue((c1.matrix == np.array([[2], [7]])).all())
63+
self.assertListEqual(c1.keys(), ['b', 'd'])
64+
del c1['d']
65+
del c1['b']
66+
self.assertTrue((c1.matrix == np.array([[]])).all())
67+
self.assertListEqual(c1.keys(), [])
68+
5769
def test_dict_init(self):
5870
c1 = DictLikeMatrixWrapper(['a', 'b'], {'a': 1, 'b': 2})
5971
self._checks(c1)
@@ -77,7 +89,7 @@ def run_tests():
7789
def main():
7890
l = unittest.TestLoader()
7991
runner = unittest.TextTestRunner()
80-
print("\n\nTesting Base Models")
92+
print("\n\nTesting Containers")
8193
result = runner.run(l.loadTestsFromTestCase(TestDictLikeMatrixWrapper)).wasSuccessful()
8294

8395
if not result:

0 commit comments

Comments
 (0)