# mnist_with_declarative_trainer.py

import torch.nn.functional as F
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy, Loss
from torch import nn
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor
from pipelinex import NetworkTrain


class Net(nn.Module):
    """Small convolutional network for MNIST digit classification."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 1 input channel -> 10 feature maps
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10 -> 20 feature maps
        self.conv2_drop = nn.Dropout2d()  # channel-wise dropout for regularization
        self.fc1 = nn.Linear(320, 50)  # 320 = 20 channels * 4 * 4 spatial positions
        self.fc2 = nn.Linear(50, 10)  # 10 classes (digits 0-9)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten feature maps to (batch_size, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)  # active only in training mode
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)  # log-probabilities, paired with NLLLoss below
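
# Shape walk-through for Net.forward, assuming the standard 28x28 MNIST input
# (a sketch for orientation; N is the batch size):
#   input                   (N, 1, 28, 28)
#   conv1 -> max_pool2d     (N, 10, 12, 12)   # (28 - 5 + 1) / 2 = 12
#   conv2 -> max_pool2d     (N, 20, 4, 4)     # (12 - 5 + 1) / 2 = 4
#   view(-1, 320)           (N, 320)          # 20 * 4 * 4 = 320
#   fc1 -> fc2              (N, 10)           # per-class log-probabilities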


if __name__ == "__main__":
train_params = {
# loss_fn accepts a loss function at https://pytorch.org/docs/stable/nn.html#loss-functions
"loss_fn": nn.NLLLoss(),
# epochs accepts an integer
"epochs": 2,
# [Optional] seed (random seed) accepts an integer
"seed": 42,
# optimizer accepts optimizers at https://pytorch.org/docs/stable/optim.html
"optimizer": SGD,
# optimizer_params accepts parameters for the specified optimizer
"optimizer_params": {"lr": 0.01, "momentum": 0.5},
# train_data_loader_params accepts args at https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
"train_data_loader_params": {"batch_size": 64, "num_workers": 0},
# val_data_loader_params accepts args at https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
"val_data_loader_params": {"batch_size": 1000, "num_workers": 0},
# [Optional] evaluation_metrics accepts dict of metrics at https://pytorch.org/ignite/metrics.html
"evaluation_metrics": {
"accuracy": Accuracy(),
"loss": Loss(loss_fn=nn.NLLLoss()),
},
# [Optional] evaluate_train_data (when to compute evaluation metrics using train_dataset)
# accepts events at https://pytorch.org/ignite/engine.html#ignite.engine.Events
"evaluate_train_data": "EPOCH_COMPLETED",
# [Optional] evaluate_val_data (when to compute evaluation metrics using val_dataset)
# accepts events at https://pytorch.org/ignite/engine.html#ignite.engine.Events
"evaluate_val_data": "EPOCH_COMPLETED",
# [Optional] progress_update (whether to show progress bar using tqdm package) accepts bool
"progress_update": True,
        # [Optional] scheduler accepts a param scheduler at
        # https://pytorch.org/ignite/contrib/handlers.html#module-ignite.contrib.handlers.param_scheduler
"scheduler": LinearCyclicalScheduler,
# [Optional] scheduler_params accepts parameters for the specified scheduler
"scheduler_params": {
"param_name": "lr",
"start_value": 0.001,
"end_value": 0.01,
"cycle_epochs": 2,
"cycle_mult": 1.0,
"start_value_mult": 1.0,
"end_value_mult": 1.0,
"save_history": False,
},
        # [Optional] model_checkpoint accepts a ModelCheckpoint at
        # https://pytorch.org/ignite/handlers.html#ignite.handlers.ModelCheckpoint
"model_checkpoint": ModelCheckpoint,
        # [Optional] model_checkpoint_params accepts parameters for the specified ModelCheckpoint
"model_checkpoint_params": {
"dirname": "../checkpoint",
"filename_prefix": "model",
"save_interval": None,
"n_saved": 1,
"atomic": True,
"require_empty": False,
"create_dir": True,
"save_as_state_dict": True,
},
        # [Optional] early_stopping_params accepts parameters for a flexible version of EarlyStopping at
        # https://pytorch.org/ignite/handlers.html#ignite.handlers.EarlyStopping
"early_stopping_params": {
# metric (metric to monitor to determine whether to stop early) accepts str
"metric": "loss",
            # minimize (if True, a smaller metric value is considered better) accepts bool
"minimize": True,
# a parameter for ignite.handlers.EarlyStopping
"patience": 1000,
# a parameter for ignite.handlers.EarlyStopping
"min_delta": 0.0,
# a parameter for ignite.handlers.EarlyStopping
"cumulative_delta": False,
},
# [Optional] time_limit (time limit for training in seconds) accepts an integer
"time_limit": 3600,
# [Optional] train_dataset_size_limit accepts an integer
# "train_dataset_size_limit": 128,
# [Optional] val_dataset_size_limit accepts an integer
# "val_dataset_size_limit": 2000,
# [Optional] mlflow_logging: If True and MLflow is installed, MLflow logging is enabled.
"mlflow_logging": False,
}
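
    # The "declarative" idea: train_params is plain data, so the same
    # configuration could also live in a YAML parameters file and be resolved
    # with pipelinex's HatchDict instead of Python code. A rough, unverified
    # sketch ("=:" is HatchDict's import marker; exact syntax may vary by
    # version):
    #
    #   train_params:
    #     loss_fn: {=: torch.nn.NLLLoss, _: }
    #     epochs: 2
    #     optimizer: {=: torch.optim.SGD}
    #     optimizer_params: {lr: 0.01, momentum: 0.5}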

    # Build the trainer function from the declarative parameters above.
    nn_train = NetworkTrain(**train_params)

    # Standard MNIST preprocessing: normalize with the training-set mean (0.1307)
    # and standard deviation (0.3081).
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    train_dataset = MNIST(download=True, root=".", transform=data_transform, train=True)
    val_dataset = MNIST(download=False, root=".", transform=data_transform, train=False)

    initial_model = Net()
    # A NetworkTrain instance is callable: calling it runs training and returns
    # the trained model.
    trained_model = nn_train(initial_model, train_dataset, val_dataset)
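
    # A minimal follow-up sketch: persist the trained weights and run a single
    # prediction. torch.save and state_dict are standard PyTorch; the filename
    # "mnist_model.pt" is an arbitrary choice for illustration.
    import torch  # needed only for this sketch

    torch.save(trained_model.state_dict(), "mnist_model.pt")

    trained_model.eval()  # switch off dropout for inference
    with torch.no_grad():
        image, label = val_dataset[0]  # transformed (1, 28, 28) tensor and int label
        log_probs = trained_model(image.unsqueeze(0))  # add batch dim -> (1, 10)
        prediction = log_probs.argmax(dim=1).item()
    print(f"predicted: {prediction}, actual: {label}")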