Skip to content

Commit

Permalink
dataset cache & add STFT
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenGarnets committed Jan 24, 2020
1 parent 9a66afe commit 5359093
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 21 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/cache
/data
/data_test
/model
Binary file modified __pycache__/dataset.cpython-36.pyc
Binary file not shown.
41 changes: 38 additions & 3 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import librosa
import librosa.display
import tqdm
from tqdm import tqdm
from scipy.fftpack import fft
from scipy import signal

Expand Down Expand Up @@ -81,13 +81,46 @@ def targetToTensor(self, target):


def music_load(filename):
    """Load a cached feature array from disk and return it as a tensor.

    The audio is no longer decoded here; this function only reads an
    array that was previously saved in NumPy ``.npy`` format.

    Args:
        filename: Path to a ``.npy`` file readable by ``np.load``
            (the training script passes ``"./cache/<song>.npy"``).

    Returns:
        A ``torch.Tensor`` built with ``torch.from_numpy``, so it has the
        same shape and dtype as the stored array.
    """
    return torch.from_numpy(np.load(filename))

def music_cache_make(filelist, song_dir="./data_test/songs/", cache_dir="./cache/"):
    """Pre-compute STFT magnitude features for each song and cache them.

    For every name in ``filelist`` the file ``<song_dir>/<name>/nofx.ogg``
    is decoded at 44.1 kHz, zero-padded, cut into frames of 1764 samples
    (40 ms) and each frame is transformed with three STFTs that differ only
    in window length (441, 882 and 1764 samples).  The stacked magnitudes
    are written to ``<cache_dir>/<name>.npy`` (``np.save`` appends the
    ``.npy`` suffix).

    Args:
        filelist: Iterable of song directory names.
        song_dir: Directory containing one sub-directory per song.
        cache_dir: Directory the ``.npy`` files are written to; created
            if it does not exist.

    Returns:
        The framed waveform (2-D array, one row per 1764-sample frame) of
        the last song processed, or ``None`` when ``filelist`` is empty.
    """
    import os

    names = list(filelist)
    if not names:
        # Nothing to cache; the original code crashed on ``return y`` here.
        return None

    # The cache directory is git-ignored, so it is absent on a fresh clone.
    os.makedirs(cache_dir, exist_ok=True)

    frame_len = int(44100 * 0.04)  # 1764 samples = 40 ms at 44.1 kHz
    win_lengths = (441, 882, 1764)

    frames = None
    for filename in tqdm(names):
        y, sr = librosa.load(os.path.join(song_dir, filename, "nofx.ogg"), sr=44100)

        # Pad up to a multiple of frame_len.  NOTE: when len(y) is already a
        # multiple, a full extra frame of zeros is appended (kept as-is so
        # frame counts stay compatible with previously built caches).
        padding = np.zeros(frame_len - len(y) % frame_len)
        frames = np.reshape(np.hstack([y, padding]), (-1, frame_len))

        # hop_length exceeds the frame length; with librosa's default
        # centering this presumably yields one STFT column per frame
        # (TODO confirm against the installed librosa version).
        features = [
            [
                np.abs(librosa.stft(frame, n_fft=1764, hop_length=2048, win_length=win))
                for win in win_lengths
            ]
            for frame in frames
        ]

        # float64 matches what the former list round-trip produced.
        np.save(os.path.join(cache_dir, filename), np.array(features, dtype=np.float64))

    return frames

def timeStamp(filename, term) :
Expand Down Expand Up @@ -220,6 +253,8 @@ def timeStamp(filename, term) :
return return_timestamp

if __name__ == "__main__":
y, sr = librosa.load("./data/songs/rootsphere_lastnote/nofx.ogg", sr=44100)
KshDataset.timeStamp("./data/songs/rootsphere_lastnote/exh.ksh", y.shape[0])
filenames = os.listdir("./data_test/songs/")
KshDataset.music_cache_make(filenames)
#y, sr = librosa.load("./data/songs/rootsphere_lastnote/nofx.ogg", sr=44100)
#KshDataset.timeStamp("./data/songs/rootsphere_lastnote/exh.ksh", y.shape[0])
#print(y.shape[0]//441)
Binary file modified net/__pycache__/model.cpython-36.pyc
Binary file not shown.
17 changes: 6 additions & 11 deletions net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self):
super(voltexNet, self).__init__()

self.conv1 = nn.Sequential(
nn.Conv1d(1, 128, kernel_size=3, stride=1, padding=1),
nn.Conv1d(3, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(128),
nn.ReLU())
self.conv2 = nn.Sequential(
Expand Down Expand Up @@ -39,24 +39,20 @@ def __init__(self):
nn.ReLU(),
nn.MaxPool1d(3,stride=3))
self.conv7 = nn.Sequential(
nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(512),
nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(3,stride=3))
self.conv8 = nn.Sequential(
nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(0.5))

self.fc = nn.Linear(1024, 4)
self.fc = nn.Linear(256, 4)

#self.LSTM = nn.LSTM(input_size = 4, hidden_size = 2, bidirectional=True)
#self.fc3 = nn.Linear(hidden_size*2,output_size)
#self.output = F.softmax(linear(output),1)


def forward(self, x):
x = x.squeeze()

out = self.conv1(x)
out = self.conv2(out)
Expand All @@ -65,7 +61,6 @@ def forward(self, x):
out = self.conv5(out)
out = self.conv6(out)
out = self.conv7(out)
out = self.conv8(out)

out = out.reshape(out.size(0), out.size(1)* out.size(2))
#print(out.shape)
Expand All @@ -83,7 +78,7 @@ def forward(self, x):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = voltexNet().to(device)

summary(model, [(1, 1764)])
summary(model, [(3, 883, 1)])

'''
class VoltexLSTM
Expand Down
13 changes: 6 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def infer(model, device, batch, filename, savename) :
def main():

model = voltexNet()
model.load_state_dict(torch.load("./model_99_.pth"))
#model.load_state_dict(torch.load("./model_99_.pth"))
#print ("load model")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand All @@ -97,7 +97,7 @@ def main():
filenames = os.listdir(dirname)
valnames = os.listdir(valname)

batch = 256
batch = 128
song_index = 0
best_Acc = 0

Expand All @@ -114,9 +114,9 @@ def main():
full_filename = os.path.join(dirname, filename)
#print(full_filename)

input = KshDataset.music_load(full_filename + '/nofx.ogg')
input = KshDataset.music_load("./cache/" + filename+'.npy')
#print(input.shape)
input = input.reshape(input.shape[0], 1, -1)
#input = input.reshape(input.shape[0], 1, -1)
target = KshDataset.timeStamp(full_filename + '/exh.ksh', input.shape[0])
try :
if target == None:
Expand Down Expand Up @@ -148,7 +148,7 @@ def main():
optimizer.zero_grad()
tmp_batch = input.shape[0] - i
if tmp_batch > 1 :
pred = model(input[i:i+tmp_batch-1],tmp_batch-1)
pred = model(input[i:i+tmp_batch-1])

loss = criterion(pred.squeeze(), target[i:i+tmp_batch-1].squeeze())
loss.backward()
Expand All @@ -165,8 +165,7 @@ def main():
#model.to(torch.device("cpu"))
for filename in tqdm(valnames):
full_filename = os.path.join(valname, filename)
input = KshDataset.music_load(full_filename + '/nofx.ogg')
input = input.reshape(input.shape[0], 1, -1)
input = KshDataset.music_load("./cache/" + filename+'.npy')
input = input.to(device, dtype=torch.float)
target = KshDataset.timeStamp(full_filename + '/exh.ksh', input.shape[0])

Expand Down

0 comments on commit 5359093

Please sign in to comment.