|
19 | 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
# Shortcut so each option below reads as one compact call.
_ = parser.add_argument  # define add_argument shortcut

# Data and model selection
_('--data', default='./data/processed-data', type=str, help='location of the video data')
_('--model', default='CortexNet', type=str, help='type of auto-encoder')
_('--mode', required=True, type=str, help='training mode [MatchNet|TempoNet]')

# Architecture geometry
_('--size', nargs='*', type=int, default=(3, 32, 64, 128, 256), metavar='S', help='number and size of hidden layers')
_('--spatial-size', nargs=2, type=int, default=(256, 256), metavar=('H', 'W'), help='frame cropping size')

# Optimisation hyper-parameters
_('--lr', default=0.1, type=float, help='initial learning rate')
_('--momentum', default=0.9, type=float, metavar='M', help='momentum')
_('--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay')

# Loss-term multipliers
_('--mu', default=1, type=float, dest='mu', metavar='μ', help='matching MSE multiplier')
_('--tau', default=0.1, type=float, dest='tau', metavar='τ', help='temporal CE multiplier')
# NOTE: no type= here on purpose — the default sentinel 'τ' is a string and is
# resolved to a float after parsing (π follows τ unless overridden).
_('--pi', default='τ', dest='pi', metavar='π', help='periodical CE multiplier')

# Training schedule
_('--epochs', default=10, type=int, help='upper epoch limit')
_('--batch-size', default=20, type=int, metavar='B', help='batch size')
_('--big-t', default=10, type=int, metavar='T', help='sequence length')
_('--seed', default=0, type=int, help='random seed')

# Logging / persistence / runtime
_('--log-interval', default=10, type=int, metavar='N', help='report interval')
_('--save', default='model.pth.tar', type=str, help='path to save the final model')
_('--cuda', action='store_true', help='use CUDA')
# default is an empty tuple; a single int value is normalised to a 1-tuple later
_('--view', default=tuple(), type=int, metavar='V', help='samples to view at the end of every log-interval batches')
|
|
# Post-process the parsed CLI options into their canonical types.
args.size = tuple(args.size)  # nargs='*' yields a list; cast to tuple
if args.lr_decay:
    # presumably --lr-decay is a multi-value option defined above — verify against parser
    args.lr_decay = tuple(args.lr_decay)
if isinstance(args.view, int):  # isinstance, not `type(...) is int` (PEP 8)
    args.view = (args.view,)  # a single sample index becomes a 1-tuple
# 'τ' is the sentinel default: π follows τ unless explicitly overridden on the CLI
args.pi = args.tau if args.pi == 'τ' else float(args.pi)

# Print current options
print('CLI arguments:', ' '.join(argv[1:]))
|
@@ -109,7 +110,7 @@ def main():
|
109 | 110 | # Build the model
|
110 | 111 | if args.model == 'model_01':
|
111 | 112 | from model.Model01 import Model01 as Model
|
112 |
| - elif args.model == 'model_02': |
| 113 | + elif args.model == 'model_02' or args.model == 'CortexNet': |
113 | 114 | from model.Model02 import Model02 as Model
|
114 | 115 | elif args.model == 'model_02_rg':
|
115 | 116 | from model.Model02 import Model02RG as Model
|
@@ -251,11 +252,11 @@ def compute_loss(x_, next_x, y_, state_, periodic=False):
|
251 | 252 | if from_past:
|
252 | 253 | mismatch = y[0] != from_past[1]
|
253 | 254 | ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state, periodic=True)
|
254 |
| - loss += mse_loss * args.mu + ce_loss[0] * args.lambda_ + ce_loss[1] * args.pi |
| 255 | + loss += mse_loss * args.mu + ce_loss[0] * args.tau + ce_loss[1] * args.pi |
255 | 256 | for t in range(0, min(args.big_t, x.size(0)) - 1): # first batch we go only T - 1 steps forward / backward
|
256 | 257 | mismatch = y[t + 1] != y[t]
|
257 | 258 | ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state)
|
258 |
| - loss += mse_loss * args.mu + ce_loss * args.lambda_ |
| 259 | + loss += mse_loss * args.mu + ce_loss * args.tau |
259 | 260 |
|
260 | 261 | # compute gradient and do SGD step
|
261 | 262 | model.zero_grad()
|
|
0 commit comments