Skip to content

Commit

Permalink
modified the total epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
liouxy committed Nov 25, 2017
1 parent 7247680 commit d86f470
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument('--batchSize', type=int, default=125, help='training batch size')
parser.add_argument('--testBatchSize', type=int, default=10, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--nEpochs', type=int, default=600, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01')
parser.add_argument('--cuda', action='store_true', help='use cuda?')
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
Expand All @@ -24,7 +24,7 @@
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
parser.add_argument("--step", type=int, default=100, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=500")
parser.add_argument("--step", type=int, default=150, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=500")
opt = parser.parse_args()

def main():
Expand Down Expand Up @@ -118,7 +118,7 @@ def train(training_data_loader, optimizer, model, criterion, epoch):
loss.data[0]))

def save_checkpoint(model, epoch):
model_out_path = "model/" + "model_epoch_{}.pth".format(epoch)
model_out_path = "model/model" + "model_epoch_{}.pth".format(epoch)
state = {"epoch": epoch, "model": model}

if not os.path.exists("model/"):
Expand Down

0 comments on commit d86f470

Please sign in to comment.