Commit 68b047b

Small refactoring
- Split up the big train function: train now only trains (and optionally validates); a "run" function builds the model, starts training, and runs the test.
- Replace the Namespace from argparse with explicit parameters. I originally used Namespace because I expected many parameters to change, and it would have been tedious to constantly pass all of them. Our code is now more stable, so this shouldn't be a problem anymore. The gain is that functions can be reused more easily, for example in a hyperparameter-iteration script (see the sketch below).
- Also make the model parameterizable through constructor parameters.
- Allow having no validation set by setting train-val-split to 1.0.
1 parent e25e052 commit 68b047b
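
As an illustration of the reuse this enables, here is a minimal sketch of a hyperparameter-iteration script built on the refactored functions. Only the StringNet, create_dataloader, train, and test signatures come from this commit; the data path, target size, and swept values are placeholders.

import torch

from model import StringNet
from train import create_dataloader, train, test

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the loaders once; 'data/car' and (128, 64) are placeholder values.
dataloaders = create_dataloader('data/car', target_size=(128, 64),
                                train_val_split=0.8, batch_size=32)

# Sweep one constructor parameter; everything else keeps its default.
for hidden_dim in (50, 100, 200):
    model = StringNet(11, seq_length=15, batch_size=32,
                      lstm_hidden_dim=hidden_dim).to(device)
    history = train(model, dataloaders['train'], dataloaders.get('val'),
                    lr=1e-4, epochs=10)
    results = test(model, dataloaders['test'])
    print(f"lstm_hidden_dim={hidden_dim}: {results}")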

File tree

2 files changed: +73 -47 lines

model.py

Lines changed: 20 additions & 14 deletions
@@ -26,7 +26,9 @@ def forward(self, x):
 
 
 class StringNet(nn.Module):
-    def __init__(self, n_classes, seq_length, batch_size):
+    def __init__(self, n_classes: int, seq_length: int, batch_size: int,
+                 lstm_hidden_dim: int = 100, bidirectional: bool = False, lstm_layers: int = 2,
+                 lstm_dropout: float = 0.5, fc2_dim: int = 100):
         """
         In the constructor we instantiate two nn.Linear modules and assign them as
         member variables.
@@ -36,9 +38,9 @@ def __init__(self, n_classes, seq_length, batch_size):
         self.n_classes = n_classes
         self.seq_length = seq_length
         self.batch_size = batch_size
-        self.hidden_dim = 100
-        self.bidirectional = False
-        self.lstm_layers = 2
+        self.lstm_hidden_dim = lstm_hidden_dim
+        self.bidirectional = bidirectional
+        self.lstm_layers = lstm_layers
 
         self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1, stride=1)
         self.bn1 = nn.BatchNorm2d(64)
@@ -60,20 +62,21 @@ def __init__(self, n_classes, seq_length, batch_size):
         # self.res_block7 = ResBlock(512, 512)
         # self.res_block8 = ResBlock(512, 512)
 
-        self.lstm_forward = nn.LSTM(3072, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
-                                    dropout=0.5)
+        self.lstm_forward = nn.LSTM(3072, self.lstm_hidden_dim, num_layers=self.lstm_layers, bias=True,
+                                    dropout=lstm_dropout)
 
-        self.lstm_backward = nn.LSTM(3072, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
-                                     dropout=0.5)
+        if self.bidirectional:
+            self.lstm_backward = nn.LSTM(3072, self.lstm_hidden_dim, num_layers=self.lstm_layers, bias=True,
+                                         dropout=lstm_dropout)
 
-        self.fc1 = nn.Linear(self.hidden_dim * self.directions, 100)
+        self.fc1 = nn.Linear(self.lstm_hidden_dim * self.directions, fc2_dim)
         self.dropout = nn.Dropout(p=0.5)
-        self.fc2 = nn.Linear(100, n_classes)
+        self.fc2 = nn.Linear(fc2_dim, n_classes)
 
     def init_hidden(self, input_length):
         # The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim)
-        return (torch.zeros(self.lstm_layers * self.directions, input_length, self.hidden_dim).to(device),
-                torch.zeros(self.lstm_layers * self.directions, input_length, self.hidden_dim).to(device))
+        return (torch.zeros(self.lstm_layers * self.directions, input_length, self.lstm_hidden_dim).to(device),
+                torch.zeros(self.lstm_layers * self.directions, input_length, self.lstm_hidden_dim).to(device))
 
 
     def forward(self, x):
         """
@@ -104,8 +107,11 @@ def forward(self, x):
         hidden = self.init_hidden(current_batch_size)
 
         outs1, _ = self.lstm_forward(features, hidden)
-        outs2, _ = self.lstm_backward(features.flip(0), hidden)
-        outs = outs1.add(outs2.flip(0))
+        if self.bidirectional:
+            outs2, _ = self.lstm_backward(features.flip(0), hidden)
+            outs = outs1.add(outs2.flip(0))
+        else:
+            outs = outs1
 
         # Decode the hidden state of the last time step
         outs = self.fc1(outs)
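
With the new constructor parameters, a model variant can now be configured at the call site instead of by editing model.py. A sketch with illustrative values (note that bidirectional=True assumes self.directions is updated to 2 elsewhere in the class, which this diff does not show):

model = StringNet(n_classes=11, seq_length=15, batch_size=32,
                  lstm_hidden_dim=200, bidirectional=True,
                  lstm_layers=3, lstm_dropout=0.3, fc2_dim=128)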

train.py

Lines changed: 53 additions & 33 deletions
@@ -74,10 +74,11 @@ def parse_args():
     return args
 
 
-def create_dataloader(args: Namespace, verbose: bool = False) -> Dict[str, DataLoader]:
+def create_dataloader(data_path, target_size, train_val_split, batch_size,
+                      verbose: bool = False) -> Dict[str, DataLoader]:
     # Data augmentation and normalization for training
     # Just normalization for validation
-    width, height = args.target_size
+    width, height = target_size
     data_transforms = {
         'train': transforms.Compose([
             transforms.Resize((width, height)),
@@ -94,23 +95,24 @@ def create_dataloader(args: Namespace, verbose: bool = False) -> Dict[str, DataL
     }
 
     # Load dataset
-    dataset = CAR(args.data, transform=data_transforms, train_val_split=args.train_val_split, verbose=verbose)
+    dataset = CAR(data_path, transform=data_transforms, train_val_split=train_val_split, verbose=verbose)
     if verbose:
         print(dataset)
 
     # Create training and validation dataloaders
+    loader_names = ['train', 'test']
+    if train_val_split < 1.0:
+        loader_names.append('val')
     dataloaders_dict = {
         x: DataLoader(dataset.subsets[x],
-                      batch_size=args.batch_size,
+                      batch_size=batch_size,
                       shuffle=True,
                       num_workers=4
-                      ) for x in ['train', 'test', 'val']
+                      ) for x in loader_names
     }
     return dataloaders_dict
 
 
-def build_model(n_classes: int, seq_length: int, batch_size: int) -> StringNet:
-    return StringNet(n_classes, seq_length, batch_size)
 
 
 def loss_func():
@@ -173,38 +175,57 @@ def calc_acc(output, targets: List[str]):
     return acc
 
 
-def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+def run(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
     set_seed(seed)
-
-    # Load dataset and create data loaders
-    dataloaders = create_dataloader(args, verbose)
-
-    seq_length = 15
+    timer = Timer()
+    timer.start()
+    seq_length = 15  # TODO: make this a parameter
 
     if args.load_path is not None and Path(args.load_path).is_file():
         print("Loading model weights from: " + args.load_path)
        model = torch.load(args.load_path)
     else:
-        model = build_model(11, seq_length, args.batch_size).to(device)
+        model = StringNet(11, seq_length, args.batch_size).to(device)
+
+    # Load dataset and create data loaders
+    dataloaders = create_dataloader(args.data, target_size=args.target_size,
+                                    train_val_split=args.train_val_split,
+                                    batch_size=args.batch_size, verbose=verbose)
+
+    # Train
+    history = train(model, dataloaders['train'], dataloaders.get('val', None), lr=args.lr, epochs=args.epochs,
+                    log_path=args.log, save_path=args.save_path, verbose=verbose)
+
+    # Test
+    test_results = test(model, dataloaders['test'], verbose)
+    print("Test | " + format_status_line(test_results))
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    timer.stop()
+    test_results["total_training_time"] = timer.total()
+    return history, test_results
+
+
+def train(model: StringNet, train_data, val_data=None, lr=1e-4, epochs=100,
+          log_path: str = None, save_path: str = None,
+          verbose: bool = False) -> List[Dict[str, Any]]:
+    # TODO: Early stopping
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     floss = loss_func()
 
-    # Train here
     history = []
-    phase = 'train'
     batch_timer = Timer()
     epoch_timer = Timer()
-    total_batches = len(dataloaders[phase])
-    for epoch in range(args.epochs):
+    total_batches = len(train_data)
+    for epoch in range(epochs):
         model.train()
         epoch_timer.start()
         batch_timer.reset()
 
         total_loss = num_samples = total_distance = total_accuracy = 0
         dummy_images = dummy_batch_targets = None
 
-        for batch_num, (image, str_targets) in enumerate(dataloaders[phase]):
+        for batch_num, (image, str_targets) in enumerate(train_data):
            batch_timer.start()
             # string to individual ints
             int_targets = [[int(c) for c in gt] for gt in str_targets]
@@ -242,7 +263,11 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
             print("Train examples: ")
             print(model(dummy_images).argmax(2)[:, :10], dummy_batch_targets[:10])
 
-        val_results = test(model, dataloaders['val'], verbose)
+        if val_data is not None:
+            val_results = test(model, val_data, verbose)
+        else:
+            val_results = {}
+
         history_item = {}
         history_item['epoch'] = epoch + 1
         history_item['avg_dist'] = total_distance / num_samples
@@ -255,18 +280,13 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[D
         status_line = format_status_line(history_item)
         print(status_line)
 
-        write_to_csv(history_item, args.log, write_header=epoch == 0, append=epoch != 0)
+        if log_path is not None:
+            write_to_csv(history_item, log_path, write_header=epoch == 0, append=epoch != 0)
 
-        if args.save_path is not None:
-            torch.save(model, args.save_path)
+        if save_path is not None:
+            torch.save(model, save_path)
 
-    # Test here
-    test_results = test(model, dataloaders['test'], verbose)
-    status_line = format_status_line(test_results)
-    print("Test | " + status_line)
-    test_results["total_training_time"] = epoch_timer.total()
-    torch.save(model.state_dict(), './model.pth')
-    return history, test_results
+    return history
 
 
 def test(model: nn.Module, dataloader: DataLoader, verbose: bool = False) -> Dict[str, Any]:
@@ -308,10 +328,10 @@ def test(model: nn.Module, dataloader: DataLoader, verbose: bool = False) -> Dic
 if __name__ == "__main__":
     args = parse_args()
     if len(args.seed) == 1:
-        train(args, seed=args.seed[0], verbose=args.verbose)
+        run(args, seed=args.seed[0], verbose=args.verbose)
     else:
         # Get the results for every seed
-        results = [train(args, seed=seed, verbose=args.verbose) for seed in args.seed]
+        results = [run(args, seed=seed, verbose=args.verbose) for seed in args.seed]
         results = [result[1] for result in results]
         # Create dictionary to get a mapping from metric_name -> array of results of that metric
         # e.g. { 'accuracy': [0.67, 0.68] }

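Putting the last bullet of the commit message into practice: with train_val_split=1.0, create_dataloader builds no 'val' loader, and train receives val_data=None, so the per-epoch validation pass is skipped. A minimal sketch (the data path and target size are placeholders):

dataloaders = create_dataloader('data/car', target_size=(128, 64),
                                train_val_split=1.0, batch_size=32)
assert 'val' not in dataloaders  # only 'train' and 'test' loaders exist

# dataloaders.get('val') is None here, so train() sets val_results = {}.
history = train(model, dataloaders['train'], dataloaders.get('val'),
                lr=1e-4, epochs=10)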