@@ -74,10 +74,11 @@ def parse_args():
     return args


-def create_dataloader(args: Namespace, verbose: bool = False) -> Dict[str, DataLoader]:
+def create_dataloader(data_path, target_size, train_val_split, batch_size,
+                      verbose: bool = False) -> Dict[str, DataLoader]:
     # Data augmentation and normalization for training
     # Just normalization for validation
-    width, height = args.target_size
+    width, height = target_size
     data_transforms = {
         'train': transforms.Compose([
             transforms.Resize((width, height)),
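A caveat in the transform code this hunk touches: torchvision's transforms.Resize reads a 2-tuple as (height, width), while the unpacking above is width, height = target_size. Unless target_size is already stored height-first, the axes end up swapped; if it genuinely is (width, height), the call would need flipping, a minimal sketch:

    # torchvision's Resize expects (h, w); swap if target_size is (width, height)
    transforms.Resize((height, width))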
@@ -94,23 +95,24 @@ def create_dataloader(args: Namespace, verbose: bool = False) -> Dict[str, DataLoader]:
     }

     # Load dataset
-    dataset = CAR(args.data, transform=data_transforms, train_val_split=args.train_val_split, verbose=verbose)
+    dataset = CAR(data_path, transform=data_transforms, train_val_split=train_val_split, verbose=verbose)
     if verbose:
         print(dataset)

     # Create training and validation dataloaders
+    loader_names = ['train', 'test']
+    if train_val_split < 1.0:
+        loader_names.append('val')
     dataloaders_dict = {
         x: DataLoader(dataset.subsets[x],
-                      batch_size=args.batch_size,
+                      batch_size=batch_size,
                       shuffle=True,
                       num_workers=4
-                      ) for x in ['train', 'test', 'val']
+                      ) for x in loader_names
     }
     return dataloaders_dict


-def build_model(n_classes: int, seq_length: int, batch_size: int) -> StringNet:
-    return StringNet(n_classes, seq_length, batch_size)


 def loss_func():
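Decoupling create_dataloader from the argparse Namespace makes it callable on its own (from a test or a notebook), and the new loader_names logic means a run with train_val_split == 1.0 no longer requests a 'val' subset that was never created. A usage sketch with placeholder path and sizes:

    # Hypothetical call; with train_val_split=1.0 there would be no 'val' key
    loaders = create_dataloader('data/car', target_size=(128, 64),
                                train_val_split=0.8, batch_size=32)
    train_loader, val_loader = loaders['train'], loaders.get('val')

The one-line build_model wrapper added nothing over calling StringNet directly, hence its removal above.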
@@ -173,38 +175,57 @@ def calc_acc(output, targets: List[str]):
     return acc


-def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+def run(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
     set_seed(seed)
-
-    # Load dataset and create data loaders
-    dataloaders = create_dataloader(args, verbose)
-
-    seq_length = 15
+    timer = Timer()
+    timer.start()
+    seq_length = 15  # TODO: make this a parameter

     if args.load_path is not None and Path(args.load_path).is_file():
         print("Loading model weights from: " + args.load_path)
         model = torch.load(args.load_path)
     else:
-        model = build_model(11, seq_length, args.batch_size).to(device)
+        model = StringNet(11, seq_length, args.batch_size).to(device)
+
+    # Load dataset and create data loaders
+    dataloaders = create_dataloader(args.data, target_size=args.target_size,
+                                    train_val_split=args.train_val_split,
+                                    batch_size=args.batch_size, verbose=verbose)
+
+    # Train
+    history = train(model, dataloaders['train'], dataloaders.get('val', None), lr=args.lr, epochs=args.epochs,
+                    log_path=args.log, save_path=args.save_path, verbose=verbose)
+
+    # Test
+    test_results = test(model, dataloaders['test'], verbose)
+    print("Test | " + format_status_line(test_results))

-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    timer.stop()
+    test_results["total_training_time"] = timer.total()
+    return history, test_results
+
+
+def train(model: StringNet, train_data, val_data=None, lr=1e-4, epochs=100,
+          log_path: str = None, save_path: str = None,
+          verbose: bool = False) -> List[Dict[str, Any]]:
+    # TODO: Early stopping
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     floss = loss_func()

-    # Train here
     history = []
-    phase = 'train'
     batch_timer = Timer()
     epoch_timer = Timer()
-    total_batches = len(dataloaders[phase])
-    for epoch in range(args.epochs):
+    total_batches = len(train_data)
+    for epoch in range(epochs):
         model.train()
         epoch_timer.start()
         batch_timer.reset()

         total_loss = num_samples = total_distance = total_accuracy = 0
         dummy_images = dummy_batch_targets = None

-        for batch_num, (image, str_targets) in enumerate(dataloaders[phase]):
+        for batch_num, (image, str_targets) in enumerate(train_data):
             batch_timer.start()
             # string to individual ints
             int_targets = [[int(c) for c in gt] for gt in str_targets]
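The split into run (orchestration) and train (the loop) leaves the # TODO: Early stopping open, but the optional val_data argument makes it straightforward to bolt on. A patience-based sketch; the should_stop helper and the 'avg_dist' key are assumptions, not part of this patch:

    # Hypothetical helper; assumes test() reports a lower-is-better metric
    # under 'avg_dist' (the key name is a guess)
    def should_stop(val_results, state, patience=5):
        metric = val_results.get('avg_dist', float('inf'))
        if metric < state['best']:
            state['best'], state['stale'] = metric, 0
        else:
            state['stale'] += 1
        return state['stale'] >= patience

Wiring it up would take a state = {'best': float('inf'), 'stale': 0} before the epoch loop and an if val_data is not None and should_stop(val_results, state): break after validation.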
@@ -242,7 +263,11 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
             print("Train examples: ")
             print(model(dummy_images).argmax(2)[:, :10], dummy_batch_targets[:10])

-        val_results = test(model, dataloaders['val'], verbose)
+        if val_data is not None:
+            val_results = test(model, val_data, verbose)
+        else:
+            val_results = {}
+
         history_item = {}
         history_item['epoch'] = epoch + 1
         history_item['avg_dist'] = total_distance / num_samples
@@ -255,18 +280,13 @@ def train(args: Namespace, seed: int = 0, verbose: bool = False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         status_line = format_status_line(history_item)
         print(status_line)

-        write_to_csv(history_item, args.log, write_header=epoch == 0, append=epoch != 0)
+        if log_path is not None:
+            write_to_csv(history_item, log_path, write_header=epoch == 0, append=epoch != 0)

-        if args.save_path is not None:
-            torch.save(model, args.save_path)
+        if save_path is not None:
+            torch.save(model, save_path)

-    # Test here
-    test_results = test(model, dataloaders['test'], verbose)
-    status_line = format_status_line(test_results)
-    print("Test | " + status_line)
-    test_results["total_training_time"] = epoch_timer.total()
-    torch.save(model.state_dict(), './model.pth')
-    return history, test_results
+    return history


 def test(model: nn.Module, dataloader: DataLoader, verbose: bool = False) -> Dict[str, Any]:
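On the save logic: checkpointing is now opt-in via save_path and goes through torch.save(model, save_path), which pickles the whole module and matches the bare torch.load(args.load_path) in run. The removed state_dict dump to ./model.pth was the weights-only variant, which would need the class available at restore time, roughly:

    # Weights-only restore (the pattern this patch drops; arguments assumed)
    model = StringNet(11, seq_length, batch_size)
    model.load_state_dict(torch.load('./model.pth'))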
@@ -308,10 +328,10 @@ def test(model: nn.Module, dataloader: DataLoader, verbose: bool = False) -> Dict[str, Any]:
 if __name__ == "__main__":
     args = parse_args()
     if len(args.seed) == 1:
-        train(args, seed=args.seed[0], verbose=args.verbose)
+        run(args, seed=args.seed[0], verbose=args.verbose)
     else:
         # Get the results for every seed
-        results = [train(args, seed=seed, verbose=args.verbose) for seed in args.seed]
+        results = [run(args, seed=seed, verbose=args.verbose) for seed in args.seed]
         results = [result[1] for result in results]
         # Create dictionary to get a mapping from metric_name -> array of results of that metric
         # e.g. { 'accuracy': [0.67, 0.68] }
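The trailing comment describes the multi-seed aggregation; a minimal version of that mapping, assuming every test_results dict carries the same metric keys:

    from collections import defaultdict

    # metric_name -> list of per-seed values, e.g. {'accuracy': [0.67, 0.68]}
    metrics = defaultdict(list)
    for result in results:
        for name, value in result.items():
            metrics[name].append(value)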