Skip to content

Commit

Permalink
add cnn layer
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenGarnets committed Jan 11, 2020
1 parent 338a6b2 commit 7ba3316
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def infer(model, filename) :
song.synthesize(diff='ka')
song.save(filename = "infer.wav")

if __name__ == "__main__":

def main():
model = voltexNet()
#model.load_state_dict(torch.load("./model/model_bestAcc_.pth"))
#print ("load model")
Expand All @@ -81,7 +81,7 @@ def infer(model, filename) :
#input = torch.rand(128,3,80,15)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)

# kshDataset에서는 (곡의 길이 * 25 = batch) * (sample rate * 0.04)의 tensor가 넘어옴
Expand All @@ -97,7 +97,7 @@ def infer(model, filename) :
filenames = os.listdir(dirname)
valnames = os.listdir(valname)

batch = 256
batch = 512
song_index = 0
best_Acc = 0

Expand All @@ -115,6 +115,7 @@ def infer(model, filename) :
#print(full_filename)

input = KshDataset.music_load(full_filename + '/nofx.ogg')
#print(input.shape)
input = input.reshape(input.shape[0], 1, -1)
target = KshDataset.timeStamp(full_filename + '/exh.ksh', input.shape[0])
try :
Expand All @@ -136,6 +137,7 @@ def infer(model, filename) :

pred = model(input[i:i+batch],batch)


loss = criterion(pred.squeeze(), target[i:i+batch].squeeze())
loss.backward()
optimizer.step()
Expand Down Expand Up @@ -200,7 +202,7 @@ def infer(model, filename) :
if acc != 0 :
acc = acc / (len(valnames) - noneScore)

print("epoch : " + str(epoch) + "\ttrain_loss : " + str(epoch_loss/index) + "\tacc : " + str(acc) + "\n")
print("epoch : " + str(epoch) + "\ttrain_loss : " + str(epoch_loss/index) + "\ttest_acc : " + str(acc) + "\n")

if acc > best_Acc :
best_Acc = acc
Expand All @@ -215,4 +217,7 @@ def infer(model, filename) :
infer(mode, "./data_test/songs/badapple_nomico_alreco")



if __name__ == "__main__":
main()

0 comments on commit 7ba3316

Please sign in to comment.