
Commit

import torchsummary
GreenGarnets committed Jan 17, 2020
1 parent 32500cf commit fd99fe5
Showing 6 changed files with 13 additions and 35 deletions.
Binary file added Asset/modelLog.png
Binary file added __pycache__/dataset.cpython-36.pyc
Binary file not shown.
Binary file added __pycache__/music_processer.cpython-36.pyc
Binary file not shown.
Binary file modified net/__pycache__/model.cpython-36.pyc
Binary file not shown.
42 changes: 10 additions & 32 deletions net/model.py
@@ -1,22 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# note model = DDC Step Selection + Step Placement
# nove model = Step Placement Score => the FCLayer itself represents the positions of the notes (nove)
# DDC Step Selection

#torch.Size([1024, 1, 1764])
#torch.Size([1024, 128, 1764])
#torch.Size([1024, 128, 441])
#torch.Size([1024, 128, 110])
#torch.Size([1024, 256, 27])
#torch.Size([1024, 256, 6])
#torch.Size([1024, 512, 1])
#torch.Size([1024, 512, 1])
#torch.Size([1024, 512])
#torch.Size([1024, 256])
#torch.Size([1024, 4])
from torchsummary import summary

class voltexNet(nn.Module):

@@ -71,43 +56,36 @@ def __init__(self):
#self.output = F.softmax(linear(output),1)


def forward(self, x, batch):
def forward(self, x):
batch = x.size(0)

#print(x.shape)
out = self.conv1(x)
#print(out.shape)
out = self.conv2(out)
#print(out.shape)
out = self.conv3(out)
#print(out.shape)
out = self.conv4(out)
#print(out.shape)
out = self.conv5(out)
#print(out.shape)
out = self.conv6(out)
#print(out.shape)
out = self.conv7(out)
#print(out.shape)
out = self.conv8(out)
#print(out.shape)

out = out.reshape(batch,out.size(1)* out.size(2))
#print(out.shape)

out = self.fc(out)
#print(out.shape)
#out = self.fc2(out)
#print(out.shape)

# out = out.reshape(1,batch,out.size(1))
# print(out.shape)
# out, hidden = self.LSTM(out)
# print(out.shape)
# out = out.squeeze()
# print(out.shape)

return out
#return F.softmax(out,1)

if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = voltexNet().to(device)

summary(model, [(1, 1764)])

'''
class VoltexLSTM
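For reference, a minimal sketch of how the shape progression in the removed comments (1764 → 441 → 110 → 27 → 6 → 1) can be reproduced. The conv1–conv8 definitions are collapsed in this diff, so the Conv1d kernel size and the MaxPool1d(4) downsampling below are assumptions, not the committed code.

import torch
import torch.nn as nn

def conv_block(in_ch, out_ch, pool=True):
    # Conv1d with padding=1 keeps the sequence length; MaxPool1d(4) divides it by 4,
    # matching 1764 -> 441 -> 110 -> 27 -> 6 -> 1 from the removed shape comments.
    layers = [nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU()]
    if pool:
        layers.append(nn.MaxPool1d(4))
    return nn.Sequential(*layers)

x = torch.randn(2, 1, 1764)                      # (batch, channels, samples)
print(conv_block(1, 128, pool=False)(x).shape)   # torch.Size([2, 128, 1764])
print(conv_block(1, 128)(x).shape)               # torch.Size([2, 128, 441])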
6 changes: 3 additions & 3 deletions train.py
@@ -71,7 +71,7 @@ def infer(model, device, batch, filename, savename) :
def main():

model = voltexNet()
model.load_state_dict(torch.load("./model_99_.pth"))
#model.load_state_dict(torch.load("./model_99_.pth"))
#print ("load model")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -102,7 +102,7 @@ def main():
best_Acc = 0

epoch_loss = 0.0
'''

for epoch in range(0,800) :
epoch_loss = 0.0
train_loss = 0.0
@@ -213,7 +213,7 @@ def main():
break

torch.save(model.state_dict(), "./model/model_"+ str(epoch)+ "_.pth")
'''

infer(model, device, batch, "./data_test/songs/smooooch_kn/","infer.wav")
infer(model, device, batch, "./data_test/songs/dynasty_yooh/","infer2.wav")

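For completeness, a minimal sketch of the checkpoint round trip that main() relies on: torch.save at the end of an epoch, load_state_dict before inference. The `from net.model import voltexNet` import path and the `map_location` argument are assumptions; only the save and load calls themselves appear in the diff.

import torch
from net.model import voltexNet   # assumed import path; model.py lives under net/

model = voltexNet()
torch.save(model.state_dict(), "./model/model_0_.pth")      # as in the epoch loop above

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
restored = voltexNet()
restored.load_state_dict(torch.load("./model/model_0_.pth", map_location=device))
restored.to(device).eval()                                   # ready for infer()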
