-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
126 lines (116 loc) · 4.7 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import torch
import argparse
from torch.utils.data import DataLoader
from dataloader import MusicDataset, CustomDataset
from utils import save_midis, to_binary
from model import CycleGAN
from datetime import datetime
import logging
# Module-level logger. test() configures output with logging.basicConfig on
# the root logger; records from this named logger reach it via propagation.
logger = logging.getLogger(__name__)
# Optional mixed precision (currently disabled):
#   from torch.cuda.amp import autocast  # reduces GPU memory allocation
def get_time():
    """Return the current local time formatted as ``Y-M-D-h-m-s``.

    Each field is a plain integer with no zero padding, joined by ``-``
    (for example ``2024-3-7-9-5-2``).
    """
    stamp = datetime.now()
    fields = (stamp.year, stamp.month, stamp.day,
              stamp.hour, stamp.minute, stamp.second)
    return '-'.join(map(str, fields))
def make_parses(argv=None):
    """Build the command-line parser for CycleGAN testing and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list lets
            the function be called from code and tests.

    Returns:
        argparse.Namespace with ``data_dir``, ``model_dir``, ``batch_size``,
        ``model_name`` and ``test_mode``.
    """
    parser = argparse.ArgumentParser(description='Trainer')
    parser.add_argument(
        '--data-dir',
        default=None,
        type=str,
        help='Directory of custom input data; if omitted, <cwd>/data/ is used.',
    )
    parser.add_argument(
        '--model-dir',
        # Earlier checkpoint: saved_models/JC/JC_itr_83000_G_1.379626_D_0.459392.pth
        default=r'saved_models/JC/JC_itr_168000_G_1.384160_D_0.457109.pth',
        type=str,
        help='Path of the checkpoint file to load.',
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help='Parsed but currently unused: testing always uses a batch size of 1.',
    )
    parser.add_argument(
        '--model-name',
        default='JC',
        type=str,
        help='Optional: CP, JC, JP',
    )
    parser.add_argument(
        '--test-mode',
        default='A2B',
        type=str,
        help='Transfer direction passed to CycleGAN(mode=...).',
    )
    return parser.parse_args(argv)
def _make_save_dir(model_name):
    """Create the per-run output directory and return its path.

    The directory is ``<cwd>/test-<model_name>-<month>-<day>-<hour>-<minute>-<second>``
    (no zero padding, no year). The returned path keeps a trailing ``os.sep``
    because output file paths are built from it by plain concatenation.
    """
    now = datetime.now()
    stamp = '-'.join(str(v) for v in (now.month, now.day, now.hour, now.minute, now.second))
    save_dir = os.path.join(os.getcwd(), 'test-' + model_name + '-' + stamp + os.sep)
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _load_weights(net, model_path, model_name):
    """Load checkpoint weights from ``model_path`` into ``net``.

    Accepts either a mapping that wraps the weights under ``'state_dict'``
    (optionally also recording ``'model_name'``) or a bare state dict.
    Tensors are mapped to CPU, so a GPU-saved checkpoint also loads on a
    CPU-only machine; the caller moves ``net`` to its device afterwards.

    Raises:
        ValueError: if the checkpoint records a different ``model_name``.
    """
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu')
    if 'model_name' in checkpoint and checkpoint['model_name'] != model_name:
        raise ValueError(
            'checkpoint {} was saved for model {!r}, but --model-name is {!r}'.format(
                model_path, checkpoint['model_name'], model_name))
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    net.load_state_dict(state_dict)


def _as_float_tensor(x, device):
    """Return ``x`` as a float32 tensor on ``device``.

    ``torch.as_tensor`` accepts tensors, NumPy arrays and sequences and
    converts the dtype when needed. The legacy ``torch.FloatTensor(x)``
    constructor can fail when given a tensor of another dtype.
    """
    return torch.as_tensor(x, dtype=torch.float32, device=device)


def test():
    """Run a trained CycleGAN checkpoint over a dataset and save MIDI output.

    Options come from ``make_parses()``. For each dataset item, three files
    are written into a fresh timestamped directory under the current working
    directory:

    * ``<name>_transfered.mid``: binarised transfer output.
    * ``<name>_origin.mid``: the ``bar_a`` input.
    * ``<name>_cycle.mid``: binarised cycle reconstruction.

    A log file recording the checkpoint path is written alongside them.

    Raises:
        ValueError: if the checkpoint records a different ``model_name``.
    """
    args = make_parses()
    model_name = args.model_name
    mode = args.test_mode
    model_dir = args.model_dir
    # A single truthiness test picks both the data path and the dataset
    # class, so an empty --data-dir cannot pair the default path with
    # CustomDataset.
    use_default_data = not args.data_dir
    data_dir = os.path.join(os.getcwd(), 'data' + os.sep) if use_default_data else args.data_dir

    save_dir = _make_save_dir(model_name)
    logging.basicConfig(
        level=logging.INFO,
        filename=os.path.join(save_dir, 'test_{}_{}.log'.format(model_name, get_time())),
        filemode='a',  # append if the log file already exists
        format='%(asctime)s - %(message)s',
    )

    if use_default_data:
        # NOTE(review): train_mode is fixed to 'CP' regardless of
        # --model-name (default 'JC'); confirm this is intended.
        music_dataset = MusicDataset(data_dir, train_mode='CP', data_mode='full', is_train='test')
    else:
        music_dataset = CustomDataset(data_dir)
    # batch_size stays 1 and --batch-size is ignored, because each batch is
    # saved under the single name derived from data['baridx'] below.
    music_dataloader = DataLoader(music_dataset, batch_size=1, shuffle=False, num_workers=0)
    print("test dataset contains {} items".format(len(music_dataset)))

    # Build the model and load the checkpoint weights.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = CycleGAN(mode=mode)
    _load_weights(net, model_dir, model_name)
    net.to(device)
    net.eval()

    # Inference loop.
    print("---start testing...")
    logger.info('load model:{}'.format(model_dir))
    # Inference only: no_grad skips building the autograd graph, saving memory.
    with torch.no_grad():
        for data in music_dataloader:
            real_a = _as_float_tensor(data['bar_a'], device)
            real_b = _as_float_tensor(data['bar_b'], device)
            real_mixed = _as_float_tensor(data['bar_mixed'], device)
            transfered, cycle = net(real_a, real_b, real_mixed)
            # permute(0, 2, 3, 1) moves dimension 1 to the last position
            # before the arrays are exported.
            trans_np = to_binary(transfered.permute(0, 2, 3, 1).cpu().numpy())
            cycle_np = to_binary(cycle.permute(0, 2, 3, 1).cpu().numpy())
            origin_np = real_a.permute(0, 2, 3, 1).cpu().numpy()

            name = music_dataset._get_name(data['baridx'])
            out_prefix = save_dir + name
            print('save to ' + out_prefix + '_transfered.mid')
            save_midis(trans_np, out_prefix + '_transfered.mid')
            save_midis(origin_np, out_prefix + '_origin.mid')
            save_midis(cycle_np, out_prefix + '_cycle.mid')
if __name__ == "__main__":
    # Script entry point: importing this module does not start a test run.
    test()