-
Notifications
You must be signed in to change notification settings - Fork 1
/
arguments.py
executable file
·88 lines (75 loc) · 2.85 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
from itertools import chain
import torch
def get_args():
    """Build the complete configuration dictionary for an ExU-Net run.

    Settings are declared in two groups and merged into one flat dict:

    * ``system_args``: paths, worker/GPU settings and logging metadata.
    * ``experiment_args``: training schedule, optimizer, loss, model and
      feature-extraction hyper-parameters.

    A brand-new dict, with brand-new nested lists, is built on every call.
    Callers may therefore mutate the result without affecting later calls.

    Several keys are misspelled: 'weigth_decay', 'code_enhacement_loss',
    'flag_reproduciable' and 'premphasis'. They are kept verbatim because code
    elsewhere presumably looks them up by these exact spellings. Verify all
    call sites before renaming them.

    Returns:
        dict: every key of ``system_args`` followed by every key of
        ``experiment_args``, in declaration order.

    Raises:
        ValueError: if both groups define the same key. Otherwise one value
            would silently overwrite the other, which is almost certainly a
            configuration mistake.
    """
    system_args = {
        # experiment info
        'project': 'ExU-Net',
        'name': 'ExU-Net',
        'tags': ['ExU-Net'],
        'description': 'ExU-Net',
        # local
        'path_logging': '/results',
        # VoxCeleb1 DB
        'path_vox1_train': '/datas/VoxCeleb1/train',
        'path_vox1_test': '/datas/VoxCeleb1/test',
        'path_vox1_trials': '/datas/VoxCeleb1/trials.txt',
        # musan DB
        'path_musan': '/datas/musan',
        # device
        'num_workers': 20,
        # Comma-separated GPU indices as a string; presumably consumed as
        # CUDA_VISIBLE_DEVICES elsewhere (TODO confirm).
        'usable_gpu': '0,1',
        'tqdm_ncols': 90,
        # Directory containing this file, with symlinks resolved.
        'path_scripts': os.path.dirname(os.path.realpath(__file__)),
    }
    experiment_args = {
        # env
        'epoch': 500,
        'batch_size': 120,
        'number_cycle': 80,
        'number_iteration_for_log': 50,
        'rand_seed': 1234,
        'flag_reproduciable': True,  # (sic) key spelling kept for callers
        # train process
        'do_train_feature_enhancement': True,
        'do_train_code_enhancement': True,
        # optimizer
        'optimizer': 'adam',
        'amsgrad': True,
        'learning_rate_scheduler': 'step',
        'lr_start': 1e-3,
        'lr_end': 1e-7,
        'weigth_decay': 1e-4,  # (sic) key spelling kept for callers
        # criterion
        'classification_loss': 'softmax',
        'enhancement_loss': 'mse',
        'code_enhacement_loss': 'angleproto',  # (sic) key spelling kept
        'weight_classification_loss': 1,
        'weight_code_enhancement_loss': 1,
        'weight_feature_enhancement_loss': 1,
        # model
        'first_kernel_size': 7,
        'first_stride_size': (2, 1),
        'first_padding_size': 3,
        # The next three lists each have four entries; presumably one per
        # network stage (TODO confirm against the model code).
        'l_channel': [16, 32, 64, 128],
        'l_num_convblocks': [3, 4, 6, 3],
        'code_dim': 128,
        'stride': [1, 2, 2, 1],
        # data
        'nb_utt_per_spk': 2,
        'max_seg_per_spk': 500,
        # If winlen/winstep are in samples, these are 25 ms / 10 ms at the
        # 16 kHz samplerate below (assumption, TODO confirm).
        'winlen': 400,
        'winstep': 160,
        'train_frame': 254,
        'nfft': 1024,
        'samplerate': 16000,
        'nfilts': 64,
        'premphasis': 0.97,  # (sic) key spelling kept for callers
        # A callable, not a value: the returned dict is therefore not
        # JSON-serialisable as-is.
        'winfunc': torch.hamming_window,
        'test_frame': 382,
    }

    # Refuse to merge if a key appears in both groups; otherwise the
    # experiment value would silently override the system one.
    overlap = system_args.keys() & experiment_args.keys()
    if overlap:
        raise ValueError(
            'duplicate configuration keys in system_args and '
            'experiment_args: %s' % ', '.join(sorted(overlap))
        )

    # set args (system_args + experiment_args); insertion order is preserved
    return {**system_args, **experiment_args}