From 7ba3316d070965d2b916019e6cf667428c949dee Mon Sep 17 00:00:00 2001 From: Garnet <38901503+GreenGarnets@users.noreply.github.com> Date: Sun, 12 Jan 2020 02:18:24 +0900 Subject: [PATCH] add cnn layer --- train.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 2f54407..af9072f 100644 --- a/train.py +++ b/train.py @@ -68,8 +68,8 @@ def infer(model, filename) : song.synthesize(diff='ka') song.save(filename = "infer.wav") -if __name__ == "__main__": - +def main(): + model = voltexNet() #model.load_state_dict(torch.load("./model/model_bestAcc_.pth")) #print ("load model") @@ -81,7 +81,7 @@ def infer(model, filename) : #input = torch.rand(128,3,80,15) criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9) scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1) # kshDataset에서는 (곡의 길이 * 25 = batch) * (sample rate * 0.04)의 tensor가 넘어옴 @@ -97,7 +97,7 @@ def infer(model, filename) : filenames = os.listdir(dirname) valnames = os.listdir(valname) - batch = 256 + batch = 512 song_index = 0 best_Acc = 0 @@ -115,6 +115,7 @@ def infer(model, filename) : #print(full_filename) input = KshDataset.music_load(full_filename + '/nofx.ogg') + #print(input.shape) input = input.reshape(input.shape[0], 1, -1) target = KshDataset.timeStamp(full_filename + '/exh.ksh', input.shape[0]) try : @@ -136,6 +137,7 @@ def infer(model, filename) : pred = model(input[i:i+batch],batch) + loss = criterion(pred.squeeze(), target[i:i+batch].squeeze()) loss.backward() optimizer.step() @@ -200,7 +202,7 @@ def infer(model, filename) : if acc != 0 : acc = acc / (len(valnames) - noneScore) - print("epoch : " + str(epoch) + "\ttrain_loss : " + str(epoch_loss/index) + "\tacc : " + str(acc) + "\n") + print("epoch : " + str(epoch) + "\ttrain_loss : " + str(epoch_loss/index) + "\ttest_acc : " + str(acc) + "\n") if acc > best_Acc : best_Acc = acc @@ -215,4 +217,7 @@ def infer(model, filename) : infer(mode, "./data_test/songs/badapple_nomico_alreco") + +if __name__ == "__main__": + main() \ No newline at end of file