-
Notifications
You must be signed in to change notification settings - Fork 4
/
run_experiments.py
56 lines (44 loc) · 1.73 KB
/
run_experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Script used to run experiments from one or more TOML config files."""
import sys
import uuid
import logging
from os.path import isfile, join
import toml
import torch
import pandas as pd
from batteryprobe.data import create_data_loader
from batteryprobe.loops import train, evaluate
from batteryprobe.models import Baseline, AutoRegressive
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def _run_experiment(filepath):
    """Run the baseline and the AutoRegressive model for one TOML config.

    Loads the config, reads and subsamples the CSV data, and builds the data
    loaders. It then evaluates the ``Baseline`` model, trains an
    ``AutoRegressive`` model, and (outside debug mode) reloads the checkpoint
    ``model.pt`` from the run directory before evaluating it.

    Args:
        filepath: Path to a TOML config file. The keys read here are
            ``log_dir``, ``data_path``, ``skip_frequency``, ``features`` and,
            optionally, ``debug``. The whole dict is also passed on to the
            ``batteryprobe`` helpers.

    Returns:
        A tuple ``(baseline_score, model_score)``. ``model_score`` is ``None``
        in debug mode, where a sample plot is produced instead of an
        evaluation.
    """
    logger.info("Loading parameters from %s", filepath)
    params = toml.load(filepath)
    # A fresh random sub-directory per run keeps repeated runs of the same
    # config from writing into the same log directory.
    params["log_dir"] = join(params["log_dir"], uuid.uuid4().hex)
    logger.info("Run directory: %s", params["log_dir"])

    # Data
    logger.info("Reading and processing data from %s", params["data_path"])
    df = pd.read_csv(params["data_path"])
    # Keep every `skip_frequency`-th row.
    df = df.iloc[::params["skip_frequency"]]
    train_dls, val_dl = create_data_loader(df, params)
    target_col = params["features"].index("capacity")

    # Baseline
    logger.info("Evaluating baseline model")
    baseline = Baseline(target_col)
    base_score = evaluate(baseline, val_dl, target_col)
    logger.info("Baseline score: %s", base_score)

    # Training
    logger.info("Training model")
    model = AutoRegressive(params)
    train(model, (train_dls, val_dl), params)

    # Evaluation
    logger.info("Evaluating trained model")
    if params.get("debug", False):
        # Debug runs only visualise one training sample; no checkpoint is
        # reloaded and no score is computed.
        from batteryprobe.utils import plot_sample
        plot_sample(train_dls[0], target_col, n=1, model=model)
        return base_score, None

    # Reload the checkpoint that `train` is expected to have written
    # (presumably the best one) rather than using the last in-memory weights.
    model.load_state_dict(
        torch.load(join(params["log_dir"], "model.pt"))
    )
    model.eval()
    score = evaluate(model, val_dl, target_col=target_col)
    logger.info("Model score: %s (baseline: %s)", score, base_score)
    return base_score, score


def main(argv=None):
    """Run one experiment per TOML config path.

    Args:
        argv: Sequence of config file paths. Defaults to ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If any given path is not an existing file. This
            is checked for every path before any experiment starts.
    """
    # The root logger has no handler by default, and `logger.info` (unlike
    # the module-level `logging.info`) does not install one implicitly.
    # Without this call, INFO messages would be silently dropped.
    logging.basicConfig()
    config_paths = sys.argv[1:] if argv is None else list(argv)
    if not config_paths:
        logger.warning("Usage: %s CONFIG.toml [CONFIG.toml ...]", sys.argv[0])
        return
    # Validate up front: runs can be long, and a typo in the last config
    # should not surface only after all earlier experiments have finished.
    missing = [path for path in config_paths if not isfile(path)]
    if missing:
        raise FileNotFoundError(f"Not a file: {', '.join(missing)}")
    for filepath in config_paths:
        _run_experiment(filepath)


if __name__ == "__main__":
    main()