Skip to content

Commit

Permalink
CNN change; a model update will follow soon
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenGarnets committed Jan 15, 2020
1 parent 42051f5 commit 3ab2616
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 29 deletions.
Binary file added infer.wav
Binary file not shown.
Binary file added infer2.wav
Binary file not shown.
Binary file modified net/__pycache__/model.cpython-37.pyc
Binary file not shown.
44 changes: 23 additions & 21 deletions net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,43 +32,43 @@ def __init__(self):
nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.AvgPool1d(3))
nn.MaxPool1d(3,stride=3))
self.conv3 = nn.Sequential(
nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.AvgPool1d(3))
nn.MaxPool1d(3,stride=3))
self.conv4 = nn.Sequential(
nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.AvgPool1d(3))
nn.MaxPool1d(3,stride=3))
self.conv5 = nn.Sequential(
nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.AvgPool1d(3))
nn.MaxPool1d(3,stride=3))
self.conv6 = nn.Sequential(
nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(512),
nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.AvgPool1d(3))
nn.MaxPool1d(3,stride=3))
self.conv7 = nn.Sequential(
nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.AvgPool1d(3))
nn.MaxPool1d(3,stride=3))
self.conv8 = nn.Sequential(
nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.AvgPool1d(2),
nn.Dropout(0.5))

self.fc1 = nn.Linear(512, 256)
self.fc2 = nn.Linear(256, 4)

self.fc = nn.Linear(1024, 4)

self.LSTM = nn.LSTM(input_size = 4, hidden_size = 4, bidirectional=True)
#self.LSTM = nn.LSTM(input_size = 4, hidden_size = 2, bidirectional=True)
#self.fc3 = nn.Linear(hidden_size*2,output_size)
#self.output = F.softmax(linear(output),1)


def forward(self, x, batch):
Expand All @@ -94,18 +94,20 @@ def forward(self, x, batch):
out = out.reshape(batch,out.size(1)* out.size(2))
#print(out.shape)

out = self.fc1(out)
out = self.fc(out)
#print(out.shape)
out = self.fc2(out)
#out = self.fc2(out)
#print(out.shape)

#out = out.reshape(1,batch,out.size(1))
#print(out.shape)
#out, hidden = self.LSTM(out)
#print(out.shape)
#out = out.squeeze()
# out = out.reshape(1,batch,out.size(1))
# print(out.shape)
# out, hidden = self.LSTM(out)
# print(out.shape)
# out = out.squeeze()
# print(out.shape)

return out
#return F.softmax(out,1)
'''
class VoltexLSTM
Expand Down
17 changes: 9 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import os

def infer(model, filename) :
def infer(model, device, batch, filename, savename) :

# Training End, infer #
input = KshDataset.music_load(filename + '/nofx.ogg')
Expand Down Expand Up @@ -66,12 +66,12 @@ def infer(model, filename) :

song = mp.Audio(filename = (filename + "nofx.ogg"), note_timestamp = note_time_Stamp_output, fx_timestamp = fx_time_Stamp_output)
song.synthesize(diff='ka')
song.save(filename = "infer.wav")
song.save(filename = savename)

def main():

model = voltexNet()
#model.load_state_dict(torch.load("./model/model_bestAcc_.pth"))
model.load_state_dict(torch.load("./model_99_.pth"))
#print ("load model")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand All @@ -81,7 +81,7 @@ def main():
#input = torch.rand(128,3,80,15)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)

# kshDataset passes in a tensor of size (song length * 25 = batch) * (sample rate * 0.04)
Expand All @@ -97,12 +97,12 @@ def main():
filenames = os.listdir(dirname)
valnames = os.listdir(valname)

batch = 512
batch = 256
song_index = 0
best_Acc = 0

epoch_loss = 0.0

'''
for epoch in range(0,800) :
epoch_loss = 0.0
train_loss = 0.0
Expand Down Expand Up @@ -213,8 +213,9 @@ def main():
break
torch.save(model.state_dict(), "./model/model_"+ str(epoch)+ "_.pth")

infer(mode, "./data_test/songs/badapple_nomico_alreco")
'''
infer(model, device, batch, "./data_test/songs/smooooch_kn/","infer.wav")
infer(model, device, batch, "./data_test/songs/dynasty_yooh/","infer2.wav")



Expand Down

0 comments on commit 3ab2616

Please sign in to comment.