Skip to content

Commit

Permalink
loss function change
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenGarnets committed Jan 29, 2020
1 parent 14463de commit 42bb9e6
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 10 deletions.
Binary file modified __pycache__/dataset.cpython-36.pyc
Binary file not shown.
10 changes: 5 additions & 5 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ def timeStamp(filename, term) :
return_timestamp = []
for (note, fx) in zip(note_time_Stamp, fx_time_Stamp) :
if note == 1 or fx == 1 :
return_timestamp.append([1])
return_timestamp.append(1)
else :
return_timestamp.append([0])
return_timestamp.append(0)

# test log
index = 0
Expand All @@ -255,9 +255,9 @@ def timeStamp(filename, term) :
return return_timestamp

if __name__ == "__main__":
filenames = os.listdir("./test_ogg/")
KshDataset.music_cache_make(filenames)
#KshDataset.timeStamp("./data/songs/rootsphere_lastnote/exh.ksh", 3400)
#filenames = os.listdir("./test_ogg/")
#$KshDataset.music_cache_make(filenames)
KshDataset.timeStamp("./data/songs/rootsphere_lastnote/exh.ksh", 3400)
#y, sr = librosa.load("./data/songs/rootsphere_lastnote/nofx.ogg", sr=44100)
#KshDataset.timeStamp("./data/songs/rootsphere_lastnote/exh.ksh", y.shape[0])
#print(y.shape[0]//441)
4 changes: 2 additions & 2 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def infer(model, device, batch, filename, savename) :
#print(fx_time_Stamp_output)
#print(fx_time_Stamp_output)

song = mp.Audio(filename = ("./test_ogg/bgm.ogg"), note_timestamp = note_time_Stamp_output, fx_timestamp = fx_time_Stamp_output)
song = mp.Audio(filename = ("./data_test/songs/badapple_nomico_alreco/nofx.ogg"), note_timestamp = note_time_Stamp_output, fx_timestamp = fx_time_Stamp_output)
song.synthesize(diff='ka')
song.save(filename = savename)

Expand All @@ -98,7 +98,7 @@ def main():
#infer(model, device, batch, "./cache/albida.npy","./test_Output/infer.wav")
#infer(model, device, batch, "./test_ogg/nofx.npy","./test_Output/infer2.wav")
#infer(model, device, batch, "./Asset/KANA-BOON - Silhouette.ogg","./test_Output/infer3.wav")
infer(model, device, batch, "./test_ogg/bgm.npy","./test_Output/infer4.wav")
infer(model, device, batch, "./cache/badapple_nomico_alreco.npy","./test_Output/infer3.wav")



Expand Down
Binary file modified net/__pycache__/model.cpython-36.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self):

self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 2)
self.fc3 = nn.Linear(256, 1)

#self.lstm = nn.LSTM()

Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main():
model.to(device)
#input = torch.rand(128,3,80,15)

criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)

Expand Down Expand Up @@ -125,7 +125,7 @@ def main():
__ = 1

input = input.to(device, dtype=torch.float)
target = target.to(device, dtype=torch.int64)
target = target.to(device, dtype=torch.float)

#print(input.shape)
#print(target.shape)
Expand Down

0 comments on commit 42bb9e6

Please sign in to comment.