-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathget_demo.py
119 lines (106 loc) · 4.04 KB
/
get_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
sys.path.append("../")
import torch
import torch.nn as nn
import numpy as np
from model.pose_generator_norm import Generator#input 50,1,1600
from dataset.small_dataset import DanceDataset #audio input 50*1*1600
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.utils import save_image
import os
import numpy as np
import math
import itertools
import time
import datetime
from matplotlib import pyplot as plt
#import cv2
from dataset.output_helper import save_2_batch_images
import argparse
from scipy.io.wavfile import write
# ---------------------------------------------------------------------------
# Command-line arguments.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    default="",
    metavar="FILE",
    help="path to pth file",
    type=str,
)
parser.add_argument(
    "--count",
    type=int,
    default=50,
    help="controls how many samples get demo audio/pose images written",
)
parser.add_argument(
    "--output",
    # The previous default was "'/mnt/external4/output_demo'" -- the extra
    # single quotes became part of the path itself.
    default="/mnt/external4/output_demo",
    metavar="FILE",
    help="path to output",
    type=str,
)
args = parser.parse_args()
file_path = args.model
counter = args.count
output_dir = args.output
# An empty model path would otherwise only fail later inside torch.load("").
if not file_path:
    parser.error("--model is required (path to the generator .pth file)")
# ---------------------------------------------------------------------------
# Model: build the generator, restore its trained weights, move it to the GPU.
# ---------------------------------------------------------------------------
Tensor = torch.cuda.FloatTensor  # CUDA float type that input batches are cast to
generator = Generator(1)
generator.eval()  # switch layers such as dropout/batch-norm to inference behaviour
# map_location="cpu" lets a checkpoint saved on any CUDA device index load here,
# and avoids staging the weights on the GPU twice; generator.cuda() below moves
# the restored parameters onto the GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
generator.load_state_dict(torch.load(file_path, map_location="cpu"))
generator.cuda()
# ---------------------------------------------------------------------------
# Data: DanceDataset built with the "diff" argument (its meaning is defined in
# dataset.small_dataset). One sample per batch, in dataset order (no
# shuffling), so the enumerate index in the loop matches the dataset index.
# ---------------------------------------------------------------------------
data = DanceDataset("diff")
dataloader = DataLoader(
    dataset=data,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    pin_memory=False,
)
# ---------------------------------------------------------------------------
# Evaluation:
#   - run the generator over every sample;
#   - accumulate the L1 error between generated and ground-truth poses;
#   - for the first `counter` samples, write the input audio (16 kHz WAV) and
#     real/fake pose images.
# ---------------------------------------------------------------------------


def _to_pixel_coords(coords):
    """Return integer pixel coordinates for a batch of 18-joint 2-D poses.

    ``coords`` is reshaped to (-1, 18, 2). x is mapped with (x + 1) * 320 and
    y with (y + 1) * 180, i.e. [-1, 1] onto a 640x360 frame. The results are
    truncated to int. A copy is scaled, so the input array is left unmodified.
    """
    pixels = np.array(coords).reshape(-1, 18, 2)
    pixels[:, :, 0] = (pixels[:, :, 0] + 1) * 320
    pixels[:, :, 1] = (pixels[:, :, 1] + 1) * 180
    return pixels.astype(int)


criterion_pixelwise = torch.nn.L1Loss()
count = 0
total_loss = 0.0
audio_dir = os.path.join(output_dir, "audio")
# scipy's wavfile.write does not create directories; make sure the target exists.
os.makedirs(audio_dir, exist_ok=True)

# Pure inference: no_grad stops autograd from recording a graph for every batch.
with torch.no_grad():
    for i, (x, target) in enumerate(dataloader):
        audio = x.type(Tensor).transpose(1, 0)  # 50,1,1600
        pose = target.type(Tensor).view(1, -1, 36)  # 1,50,18,2 -> 1,50,36

        fake = generator(audio)
        total_loss += criterion_pixelwise(fake, pose).item()

        # Strict '<' so exactly `counter` demos are written (was `<=`, one extra).
        if count < counter:
            # Raw samples are cast directly to int16, as before.
            scaled = np.int16(x.view(-1))
            write(os.path.join(audio_dir, "{}.wav".format(i)), 16000, scaled)
            real_coors = _to_pixel_coords(pose.cpu().numpy())
            fake_coors = _to_pixel_coords(fake.cpu().numpy())
            save_2_batch_images(
                real_coors, fake_coors, batch_num=count, save_dir_start=output_dir
            )
        count += 1

# Mean per-batch L1 loss; NaN instead of ZeroDivisionError if the loader was empty.
final_loss = total_loss / count if count else float("nan")
print("final_loss:", final_loss)