From d86f470abcbff794ec3321921639d2beea7240b8 Mon Sep 17 00:00:00 2001 From: xyliu Date: Sat, 25 Nov 2017 21:53:39 +0800 Subject: [PATCH] modified the total epochs --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 50264d2..1902b76 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ parser = argparse.ArgumentParser(description='PyTorch Super Res Example') parser.add_argument('--batchSize', type=int, default=125, help='training batch size') parser.add_argument('--testBatchSize', type=int, default=10, help='testing batch size') -parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for') +parser.add_argument('--nEpochs', type=int, default=600, help='number of epochs to train for') parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01') parser.add_argument('--cuda', action='store_true', help='use cuda?') parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use') @@ -24,7 +24,7 @@ parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)") parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)") parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)") -parser.add_argument("--step", type=int, default=100, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=500") +parser.add_argument("--step", type=int, default=150, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=500") opt = parser.parse_args() def main(): @@ -118,7 +118,7 @@ def train(training_data_loader, optimizer, model, criterion, epoch): loss.data[0])) def save_checkpoint(model, epoch): - model_out_path = "model/" + "model_epoch_{}.pth".format(epoch) + model_out_path = "model/model" + "model_epoch_{}.pth".format(epoch) state = {"epoch": epoch, "model": model} if not os.path.exists("model/"):