diff --git a/infer.wav b/infer.wav
new file mode 100644
index 0000000..a836907
Binary files /dev/null and b/infer.wav differ
diff --git a/infer2.wav b/infer2.wav
new file mode 100644
index 0000000..8657086
Binary files /dev/null and b/infer2.wav differ
diff --git a/net/__pycache__/model.cpython-37.pyc b/net/__pycache__/model.cpython-37.pyc
index 60d14b6..9cebb84 100644
Binary files a/net/__pycache__/model.cpython-37.pyc and b/net/__pycache__/model.cpython-37.pyc differ
diff --git a/net/model.py b/net/model.py
index 46e20df..9b3759e 100644
--- a/net/model.py
+++ b/net/model.py
@@ -32,43 +32,43 @@ def __init__(self):
             nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm1d(128),
             nn.ReLU(),
-            nn.AvgPool1d(3))
+            nn.MaxPool1d(3,stride=3))
         self.conv3 = nn.Sequential(
             nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm1d(128),
             nn.ReLU(),
-            nn.AvgPool1d(3))
+            nn.MaxPool1d(3,stride=3))
         self.conv4 = nn.Sequential(
             nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm1d(256),
             nn.ReLU(),
-            nn.AvgPool1d(3))
+            nn.MaxPool1d(3,stride=3))
         self.conv5 = nn.Sequential(
             nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm1d(256),
             nn.ReLU(),
-            nn.AvgPool1d(3))
+            nn.MaxPool1d(3,stride=3))
         self.conv6 = nn.Sequential(
-            nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1),
-            nn.BatchNorm1d(512),
+            nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm1d(256),
             nn.ReLU(),
-            nn.AvgPool1d(3))
+            nn.MaxPool1d(3,stride=3))
         self.conv7 = nn.Sequential(
-            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
+            nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm1d(512),
             nn.ReLU(),
-            nn.AvgPool1d(3))
+            nn.MaxPool1d(3,stride=3))
         self.conv8 = nn.Sequential(
             nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
             nn.BatchNorm1d(512),
             nn.ReLU(),
-            nn.AvgPool1d(2),
             nn.Dropout(0.5))
-
-        self.fc1 = nn.Linear(512, 256)
-        self.fc2 = nn.Linear(256, 4)
+
+        self.fc = nn.Linear(1024, 4)
 
-        self.LSTM = nn.LSTM(input_size = 4, hidden_size = 4, bidirectional=True)
+        #self.LSTM = nn.LSTM(input_size = 4, hidden_size = 2, bidirectional=True)
+        #self.fc3 = nn.Linear(hidden_size*2,output_size)
+        #self.output = F.softmax(linear(output),1)
 
     def forward(self, x, batch):
 
@@ -94,18 +94,20 @@ def forward(self, x, batch):
 
 
         out = out.reshape(batch,out.size(1)* out.size(2))
         #print(out.shape)
-        out = self.fc1(out)
+        out = self.fc(out)
         #print(out.shape)
-        out = self.fc2(out)
+        #out = self.fc2(out)
         #print(out.shape)
-        #out = out.reshape(1,batch,out.size(1))
-        #print(out.shape)
-        #out, hidden = self.LSTM(out)
-        #print(out.shape)
-        #out = out.squeeze()
+        # out = out.reshape(1,batch,out.size(1))
+        # print(out.shape)
+        # out, hidden = self.LSTM(out)
+        # print(out.shape)
+        # out = out.squeeze()
+        # print(out.shape)
         return out
+        #return F.softmax(out,1)
 
 
 '''
 class VoltexLSTM
diff --git a/train.py b/train.py
index af9072f..2aaffd7 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,7 @@
 import os
 
 
-def infer(model, filename) :
+def infer(model, device, batch, filename, savename) :
     # Training End, infer
 
     # input = KshDataset.music_load(filename + '/nofx.ogg')
@@ -66,12 +66,12 @@ def infer(model, filename) :
 
     song = mp.Audio(filename = (filename + "nofx.ogg"), note_timestamp = note_time_Stamp_output, fx_timestamp = fx_time_Stamp_output)
     song.synthesize(diff='ka')
-    song.save(filename = "infer.wav")
+    song.save(filename = savename)
 
 
 def main():
     model = voltexNet()
-    #model.load_state_dict(torch.load("./model/model_bestAcc_.pth"))
+    model.load_state_dict(torch.load("./model_99_.pth"))
     #print ("load model")
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -81,7 +81,7 @@ def main():
 
     #input = torch.rand(128,3,80,15)
     criterion = nn.CrossEntropyLoss()
-    optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9)
+    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)
 
     # kshDataset passes a tensor of (song length * 25 = batch) * (sample rate * 0.04)
@@ -97,12 +97,12 @@
     filenames = os.listdir(dirname)
     valnames = os.listdir(valname)
 
-    batch = 512
+    batch = 256
     song_index = 0
     best_Acc = 0
     epoch_loss = 0.0
 
-
+    '''
    for epoch in range(0,800) :
         epoch_loss = 0.0
         train_loss = 0.0
@@ -213,8 +213,9 @@
 
                 break
 
         torch.save(model.state_dict(), "./model/model_"+ str(epoch)+ "_.pth")
 
 
-    infer(mode, "./data_test/songs/badapple_nomico_alreco")
+    '''
+    infer(model, device, batch, "./data_test/songs/smooooch_kn/","infer.wav")
+    infer(model, device, batch, "./data_test/songs/dynasty_yooh/","infer2.wav")
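
Note on the net/model.py change: every AvgPool1d(3) becomes MaxPool1d(3, stride=3) (stride-equivalent, since pooling stride defaults to the kernel size, but max-pooling keeps transient peaks that averaging smears out), the 256 -> 512 widening moves from conv6 to conv7, conv8 drops its pooling, and the fc1/fc2 head collapses into a single nn.Linear(1024, 4). Below is a minimal shape-check sketch of the patched conv stack; conv1 sits outside the hunk, so its layout here (one input channel, no pooling) and the 1764-sample window (44100 * 0.04, per the train.py comment) are assumptions, not facts from the patch.

import torch
import torch.nn as nn

def block(c_in, c_out, pool):
    # One block as it reads after the patch: Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d.
    layers = [nn.Conv1d(c_in, c_out, kernel_size=3, stride=1, padding=1),
              nn.BatchNorm1d(c_out),
              nn.ReLU()]
    if pool:
        layers.append(nn.MaxPool1d(3, stride=3))
    return nn.Sequential(*layers)

convs = nn.Sequential(
    block(1, 128, False),    # stand-in for conv1, which is not shown in this diff
    block(128, 128, True),   # conv2
    block(128, 128, True),   # conv3
    block(128, 256, True),   # conv4
    block(256, 256, True),   # conv5
    block(256, 256, True),   # conv6 (patched: stays at 256 channels)
    block(256, 512, True),   # conv7 (patched: widening to 512 happens here now)
    block(512, 512, False),  # conv8 (patched: pooling removed; Dropout omitted in the sketch)
)

x = torch.rand(8, 1, 1764)   # assumed 0.04 s window at 44.1 kHz
out = convs(x)               # six stride-3 pools: 1764 -> 588 -> 196 -> 65 -> 21 -> 7 -> 2
flat = out.reshape(8, out.size(1) * out.size(2))
print(flat.shape)            # torch.Size([8, 1024]), matching nn.Linear(1024, 4)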
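One caveat on the newly enabled model.load_state_dict(torch.load("./model_99_.pth")): torch.load restores tensors onto the device they were saved from, so a GPU-trained checkpoint fails on a CPU-only machine unless map_location is given. A hedged sketch of the safer loading pattern (the import path is assumed from the repo layout; the checkpoint name follows the patch):

import torch
from net.model import voltexNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = voltexNet()
# map_location remaps saved tensors onto the available device, so a
# GPU-saved checkpoint also loads on CPU-only machines.
model.load_state_dict(torch.load("./model_99_.pth", map_location=device))
model.to(device)
model.eval()  # fix BatchNorm running stats and disable Dropout for inference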