Skip to content

Commit d94f2ec

Browse files
committed
Uniform notation to paper
1 parent 27839ba commit d94f2ec

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

main.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,21 @@
1919
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
2020
_ = parser.add_argument # define add_argument shortcut
2121
_('--data', type=str, default='./data/processed-data', help='location of the video data')
22-
_('--model', type=str, default='model_01', help='type of auto-encoder')
23-
_('--size', type=int, default=(3, 6, 12), nargs='*', help='number and size of hidden layers', metavar='S')
22+
_('--model', type=str, default='CortexNet', help='type of auto-encoder')
23+
_('--mode', type=str, required=True, help='training mode [MatchNet|TempoNet]')
24+
_('--size', type=int, default=(3, 32, 64, 128, 256), nargs='*', help='number and size of hidden layers', metavar='S')
2425
_('--spatial-size', type=int, default=(256, 256), nargs=2, help='frame cropping size', metavar=('H', 'W'))
2526
_('--lr', type=float, default=0.1, help='initial learning rate')
2627
_('--momentum', type=float, default=0.9, metavar='M', help='momentum')
2728
_('--weight-decay', type=float, default=1e-4, metavar='W', help='weight decay')
28-
_('--mu', type=float, default=1, help='MSE multiplier', dest='mu', metavar='μ')
29-
_('--lambda', type=float, default=0.1, help='final CE stabiliser multiplier', dest='lambda_', metavar='λ')
30-
_('--pi', default='λ', help='periodical CE stabiliser multiplier', dest='pi', metavar='π')
31-
_('--epochs', type=int, default=6, help='upper epoch limit')
29+
_('--mu', type=float, default=1, help='matching MSE multiplier', dest='mu', metavar='μ')
30+
_('--tau', type=float, default=0.1, help='temporal CE multiplier', dest='tau', metavar='τ')
31+
_('--pi', default='τ', help='periodical CE multiplier', dest='pi', metavar='π')
32+
_('--epochs', type=int, default=10, help='upper epoch limit')
3233
_('--batch-size', type=int, default=20, metavar='B', help='batch size')
33-
_('--big-t', type=int, default=20, help='sequence length', metavar='T')
34+
_('--big-t', type=int, default=10, help='sequence length', metavar='T')
3435
_('--seed', type=int, default=0, help='random seed')
35-
_('--log-interval', type=int, default=200, metavar='N', help='report interval')
36+
_('--log-interval', type=int, default=10, metavar='N', help='report interval')
3637
_('--save', type=str, default='model.pth.tar', help='path to save the final model')
3738
_('--cuda', action='store_true', help='use CUDA')
3839
_('--view', type=int, default=tuple(), help='samples to view at the end of every log-interval batches', metavar='V')
@@ -43,7 +44,7 @@
4344
args.size = tuple(args.size) # cast to tuple
4445
if args.lr_decay: args.lr_decay = tuple(args.lr_decay)
4546
if type(args.view) is int: args.view = (args.view,) # cast to tuple
46-
args.pi = args.lambda_ if args.pi == 'λ' else float(args.pi)
47+
args.pi = args.tau if args.pi == 'τ' else float(args.pi)
4748

4849
# Print current options
4950
print('CLI arguments:', ' '.join(argv[1:]))
@@ -109,7 +110,7 @@ def main():
109110
# Build the model
110111
if args.model == 'model_01':
111112
from model.Model01 import Model01 as Model
112-
elif args.model == 'model_02':
113+
elif args.model == 'model_02' or args.model == 'CortexNet':
113114
from model.Model02 import Model02 as Model
114115
elif args.model == 'model_02_rg':
115116
from model.Model02 import Model02RG as Model
@@ -251,11 +252,11 @@ def compute_loss(x_, next_x, y_, state_, periodic=False):
251252
if from_past:
252253
mismatch = y[0] != from_past[1]
253254
ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state, periodic=True)
254-
loss += mse_loss * args.mu + ce_loss[0] * args.lambda_ + ce_loss[1] * args.pi
255+
loss += mse_loss * args.mu + ce_loss[0] * args.tau + ce_loss[1] * args.pi
255256
for t in range(0, min(args.big_t, x.size(0)) - 1): # first batch we go only T - 1 steps forward / backward
256257
mismatch = y[t + 1] != y[t]
257258
ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state)
258-
loss += mse_loss * args.mu + ce_loss * args.lambda_
259+
loss += mse_loss * args.mu + ce_loss * args.tau
259260

260261
# compute gradient and do SGD step
261262
model.zero_grad()

0 commit comments

Comments
 (0)