-
Notifications
You must be signed in to change notification settings - Fork 4
/
visualizations.py
82 lines (60 loc) · 2.88 KB
/
visualizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch.optim
from models import Lusch
from data_generator import load_dataset,differential_dataset
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch.nn.functional as F
def _plot_time_series(x_true, x_recon_hat, x_ahead_hat, forecast_start, horizon):
    """Plot every state component of one trajectory against the time index.

    Args:
        x_true: full ground-truth trajectory, indexed as ``x_true[t, component]``.
        x_recon_hat: model reconstruction of the observed prefix, same layout.
        x_ahead_hat: model forecast with one row per forecast step.
        forecast_start: time index at which the forecast rows are drawn.
        horizon: number of forecast steps (rows of ``x_ahead_hat``).
    """
    plt.figure(figsize=(20, 10))
    plt.plot(np.arange(x_true.shape[0]), x_true, '-')
    plt.plot(np.arange(x_recon_hat.shape[0]), x_recon_hat, '--')
    # Forecast points are placed right after the reconstructed prefix.
    plt.plot(forecast_start + np.arange(horizon), x_ahead_hat, 'r.')
    plt.xlabel("Time (n)", fontsize=20)
    plt.ylabel("State", fontsize=20)
    # Labels assume 3 state components (input_dim == 3 in main()).
    plt.legend(["x", "y", "z", "$x_{reconstructed}$", "$y_{reconstructed}$", "$z_{reconstructed}$", "Forecasted"],
               fontsize=20)
    plt.show()


def _plot_phase_space(x_true, x_recon_hat, x_ahead_hat):
    """Plot ground truth, reconstruction and forecast as 3D curves (components 0, 1, 2)."""
    plt.figure()
    ax = plt.axes(projection='3d')
    ax.plot3D(x_true[:, 0], x_true[:, 1], x_true[:, 2], 'k-')
    ax.plot3D(x_recon_hat[:, 0], x_recon_hat[:, 1], x_recon_hat[:, 2], 'b*')
    ax.plot3D(x_ahead_hat[:, 0], x_ahead_hat[:, 1], x_ahead_hat[:, 2], 'rx')
    ax.set_xlabel('$X$', fontsize=20)
    ax.set_ylabel('$Y$', fontsize=20)
    ax.set_zlabel(r'$Z$', fontsize=20)
    plt.legend(["Actual", "Reconstruction", "Forecasted"])
    plt.show()


def main():
    """Build a Lusch model, optionally load a checkpoint, and visualize one test trajectory.

    The last ``horizon`` steps of each trajectory are held out. The model
    reconstructs the remaining prefix and forecasts the held-out steps starting
    from the last observed state. Results are shown as a time-series plot and
    as a 3D trajectory plot.
    """
    koopman_dim = 64
    hidden_dim = 500
    input_dim = 3
    delta_t = 0.01
    horizon = 72  # number of trailing time steps held out and forecast
    load_chkpt = True
    chkpt_filename = "best_fixed_matrix"
    # Fall back to CPU so the script also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n = 2  # index of the test trajectory to visualize

    model = Lusch(input_dim, koopman_dim, hidden_dim=hidden_dim, delta_t=delta_t, device=device).to(device)

    X_train, X_test = load_dataset(chunk_size=1)
    # Drop the last `horizon` steps; only this prefix is fed to the model.
    X_train_recon = X_train[:, :-horizon, :]
    X_test_recon = X_test[:, :-horizon, :]

    # The model's mu/std are taken from the dataset built on the training prefix
    # (presumably normalization statistics -- confirm in data_generator).
    train_ds = differential_dataset(X_train_recon, horizon)
    model.mu = train_ds.mu.to(device)
    model.std = train_ds.std.to(device)

    if load_chkpt:
        print("LOAD CHECKPOINTS")
        # map_location lets a checkpoint saved on GPU load on the selected device.
        state_dicts = torch.load(chkpt_filename + ".pth", map_location=device)
        model.load_state_dict(state_dicts['model'])

    model.eval()
    with torch.inference_mode():
        # n:n + 1 keeps a leading batch axis of size 1.
        x_obs = X_test_recon[n:n + 1].to(device)
        x_recon_hat = model(x_obs).cpu().squeeze(0)
        # Forecast from the last observed step: embed it, advance it with
        # model.koopman_operator for `horizon` steps, and decode with model.recover.
        x_last = x_obs[:, -1:, :]
        x_ahead_hat = model.recover(
            model.koopman_operator(model.embed(x_last), horizon)
        ).cpu().squeeze(0)

    # NOTE: this backend needs a Qt5 binding (e.g. PyQt5) installed.
    mpl.use('Qt5Agg')
    x_true = X_test[n]
    _plot_time_series(x_true, x_recon_hat, x_ahead_hat, X_test_recon.shape[1], horizon)
    _plot_phase_space(x_true, x_recon_hat, x_ahead_hat)


if __name__ == '__main__':
    main()