Skip to content

Commit 89265ba

Browse files
authored
Merge pull request Kaggle#433 from marketneutral/pykalman
Add pykalman and related tests
2 parents dab6b38 + 69dbd81 commit 89265ba

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ RUN pip install flashtext && \
476476
pip install pytext-nlp && \
477477
pip install tsfresh && \
478478
pip install pymagnitude && \
479+
pip install pykalman && \
479480
/tmp/clean-layer.sh
480481

481482
# Pin Vowpal Wabbit v8.6.0 because 8.6.1 does not build or install successfully

tests/test_pykalman.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import unittest
2+
import numpy as np
3+
from pykalman import KalmanFilter
4+
from pykalman import UnscentedKalmanFilter
5+
from pykalman.sqrt import CholeskyKalmanFilter, AdditiveUnscentedKalmanFilter
6+
7+
class TestPyKalman(unittest.TestCase):
8+
def test_kalman_filter(self):
9+
kf = KalmanFilter(transition_matrices = [[1, 1], [0, 1]], observation_matrices = [[0.1, 0.5], [-0.3, 0.0]])
10+
measurements = np.asarray([[1,0], [0,0], [0,1]]) # 3 observations
11+
kf = kf.em(measurements, n_iter=5)
12+
(filtered_state_means, filtered_state_covariances) = kf.filter(measurements)
13+
(smoothed_state_means, smoothed_state_covariances) = kf.smooth(measurements)
14+
return filtered_state_means
15+
16+
def test_kalman_missing(self):
17+
kf = KalmanFilter(transition_matrices = [[1, 1], [0, 1]], observation_matrices = [[0.1, 0.5], [-0.3, 0.0]])
18+
measurements = np.asarray([[1,0], [0,0], [0,1]]) # 3 observations
19+
measurements = np.ma.asarray(measurements)
20+
measurements[1] = np.ma.masked
21+
kf = kf.em(measurements, n_iter=5)
22+
(filtered_state_means, filtered_state_covariances) = kf.filter(measurements)
23+
(smoothed_state_means, smoothed_state_covariances) = kf.smooth(measurements)
24+
return filtered_state_means
25+
26+
def test_unscented_kalman(self):
27+
ukf = UnscentedKalmanFilter(lambda x, w: x + np.sin(w), lambda x, v: x + v, transition_covariance=0.1)
28+
(filtered_state_means, filtered_state_covariances) = ukf.filter([0, 1, 2])
29+
(smoothed_state_means, smoothed_state_covariances) = ukf.smooth([0, 1, 2])
30+
return filtered_state_means
31+
32+
def test_online_update(self):
33+
kf = KalmanFilter(transition_matrices = [[1, 1], [0, 1]], observation_matrices = [[0.1, 0.5], [-0.3, 0.0]])
34+
measurements = np.asarray([[1,0], [0,0], [0,1]]) # 3 observations
35+
measurements = np.ma.asarray(measurements)
36+
measurements[1] = np.ma.masked # measurement at timestep 1 is unobserved
37+
kf = kf.em(measurements, n_iter=5)
38+
(filtered_state_means, filtered_state_covariances) = kf.filter(measurements)
39+
for t in range(1, 3):
40+
filtered_state_means[t], filtered_state_covariances[t] = \
41+
kf.filter_update(filtered_state_means[t-1], filtered_state_covariances[t-1], measurements[t])
42+
return filtered_state_means
43+
44+
def test_robust_sqrt(self):
45+
kf = CholeskyKalmanFilter(transition_matrices = [[1, 1], [0, 1]], observation_matrices = [[0.1, 0.5], [-0.3, 0.0]])
46+
ukf = AdditiveUnscentedKalmanFilter(lambda x, w: x + np.sin(w), lambda x, v: x + v, observation_covariance=0.1)
47+

0 commit comments

Comments
 (0)