-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathval.py
36 lines (26 loc) · 1010 Bytes
/
val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch

from datasets.taxibj import data_permute
def compute_errors(preds, y_true, mmn):
    """Score predictions against targets on the original (de-normalised) scale.

    Args:
        preds: Model output array. Only the slice ``preds[:, 0:2]`` (the first
            two entries along axis 1) is scored; anything further along that
            axis is ignored.
        y_true: Target array; it is subtracted against ``preds[:, 0:2]``, so
            the two must be broadcast-compatible.
        mmn: Scaler exposing ``inverse_transform``; applied to both arrays
            before any error is computed.

    Returns:
        Tuple ``(mse, mae, rmse)``: mean squared error, mean absolute error
        and root mean squared error, each averaged over every element of the
        residual ``y_true - preds[:, 0:2]``.
    """
    # De-normalise first so the metrics are reported in data units.
    denorm_preds = mmn.inverse_transform(preds)
    denorm_truth = mmn.inverse_transform(y_true)
    # NOTE(review): presumably the model emits extra channels beyond the two
    # target channels (e.g. an uncertainty head); confirm against the model.
    residual = denorm_truth - denorm_preds[:, 0:2]
    mse = np.mean(np.square(residual))
    return mse, np.mean(np.abs(residual)), np.sqrt(mse)
def valid(model, val_dataloader, mmn, device):
    """Evaluate ``model`` on a validation loader and return ``(rmse, mse, mae)``.

    The model is moved to ``device`` and switched to eval mode; it is left in
    eval mode on return. Each batch is prepared with ``data_permute`` and
    scored with ``compute_errors`` on the scale restored by ``mmn``.

    Metrics are pooled over the whole loader: each batch's MSE/MAE is weighted
    by the number of label elements in that batch, and RMSE is the square
    root of the pooled MSE. Averaging per-batch RMSEs instead is biased low
    (Jensen's inequality) and over-weights a smaller final batch.

    Args:
        model: Module called as ``model(X_c, X_p, X_t, X_meta)``.
        val_dataloader: Iterable yielding ``(X_c, X_p, X_t, X_meta, labels)``.
        mmn: Scaler passed through to ``compute_errors``.
        device: Target device for the model and for ``data_permute``.

    Returns:
        Tuple ``(rmse, mse, mae)``. All three are NaN if the loader yields no
        label elements.
    """
    model.to(device)
    model.eval()

    sq_err_sum = 0.0
    abs_err_sum = 0.0
    n_elems = 0
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for X_c, X_p, X_t, X_meta, labels in val_dataloader:
            X_c, X_p, X_t, X_meta, labels = data_permute(
                X_c, X_p, X_t, X_meta, labels, device
            )
            outputs = model(X_c, X_p, X_t, X_meta)
            y_np = labels.detach().cpu().numpy()
            mse, mae, _ = compute_errors(outputs.detach().cpu().numpy(), y_np, mmn)
            # compute_errors returns means over the residual; re-weight by the
            # element count so batches of different sizes pool correctly.
            # NOTE(review): assumes the residual has y_np.size elements, i.e.
            # preds[:, 0:2] matches the label shape — confirm.
            sq_err_sum += mse * y_np.size
            abs_err_sum += mae * y_np.size
            n_elems += y_np.size

    if n_elems == 0:
        nan = float("nan")
        return nan, nan, nan

    mse = sq_err_sum / n_elems
    mae = abs_err_sum / n_elems
    rmse = np.sqrt(mse)
    return rmse, mse, mae