-
Notifications
You must be signed in to change notification settings - Fork 23
/
diff_txt_emb_clotho_4.yaml
143 lines (126 loc) · 3.93 KB
/
diff_txt_emb_clotho_4.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# @package _global_

# Audio format. `sampling_rate` and `length` are interpolated into the
# datamodule transforms (`target_rate`, `random_crop_size`) and, together with
# `channels`, into the sample-logger callback; `channels` also sets the
# network's `in_channels`.
sampling_rate: 48000
channels: 1
length: 131072  # 2^17

# Text-encoder dimensions, interpolated into both the embedder
# (`max_length`) and the diffusion network (`embedding_*`).
encoder_max_length: 64
encoder_features: 768

# Interpolated into `trainer.val_check_interval`.
log_every_n_steps: 1000
# Lightning module configuration: `lr_*` / `ema_*` hyper-parameters plus the
# nested objects that Hydra instantiates from their `_target_` paths.
model:
  _target_: main.module_diff_txt_emb.Model
  # Exponents are written with an explicit mantissa dot (`1.0e-4`, not `1e-4`)
  # so YAML 1.1 loaders such as PyYAML also read them as floats, not strings.
  lr: 1.0e-4
  lr_beta1: 0.95
  lr_beta2: 0.999
  lr_eps: 1.0e-6
  lr_weight_decay: 1.0e-3
  ema_beta: 0.995
  ema_power: 0.7

  # Text-conditioned diffusion network.
  model:
    _target_: audio_diffusion_pytorch.AudioDiffusionConditional
    in_channels: ${channels}
    channels: 160
    patch_size: 16
    resnet_groups: 8
    kernel_multiplier_downsample: 2
    # `multipliers` and `attentions` carry seven entries, `factors` and
    # `num_blocks` six; keep the lengths in step when adding or removing a
    # level.
    multipliers: [1, 2, 4, 4, 4, 4, 4]
    factors: [2, 2, 2, 2, 2, 2]
    num_blocks: [2, 2, 2, 2, 2, 2]
    attentions: [0, 0, 0, 1, 1, 1, 1]
    attention_heads: 8
    attention_features: 64
    attention_multiplier: 2
    use_nearest_upsample: false
    use_skip_scale: true
    diffusion_type: v
    diffusion_sigma_distribution:
      _target_: audio_diffusion_pytorch.UniformDistribution
    # Must agree with the embedder output; both come from the top-level
    # `encoder_*` values.
    embedding_max_length: ${encoder_max_length}
    embedding_features: ${encoder_features}
    embedding_mask_proba: 0.1

  # NOTE(review): placed as a sibling of the inner `model` (an argument of
  # main.module_diff_txt_emb.Model) - confirm against that class's signature.
  embedder:
    _target_: audio_diffusion_pytorch.T5Embedder
    model: t5-base
    max_length: ${encoder_max_length}
# Top-level so that both datasets below can interpolate it as `${batch_size}`.
batch_size: 24

# Data pipeline. The train and validation datasets are configured identically
# apart from `split`. `${data_dir}` is not defined in this file and is
# expected to come from the parent config.
datamodule:
  _target_: main.module_diff_txt_emb.Datamodule

  dataset_train:
    _target_: audio_data_pytorch.ClothoDataset
    root: ${data_dir}
    split: train
    batch_size: ${batch_size}
    preprocess_sample_rate: 48000
    preprocess_transforms:
      _target_: audio_data_pytorch.AllTransform
      # 480000 / 48000 = 10, i.e. presumably a 10 s crop - confirm units.
      crop_size: 480000
      stereo: true
    transforms:
      _target_: audio_data_pytorch.AllTransform
      target_rate: ${sampling_rate}
      random_crop_size: ${length}
      mono: true

  dataset_valid:
    _target_: audio_data_pytorch.ClothoDataset
    root: ${data_dir}
    split: valid
    batch_size: ${batch_size}
    preprocess_sample_rate: 48000
    preprocess_transforms:
      _target_: audio_data_pytorch.AllTransform
      crop_size: 480000
      stereo: true
    transforms:
      _target_: audio_data_pytorch.AllTransform
      target_rate: ${sampling_rate}
      random_crop_size: ${length}
      mono: true

  num_workers: 8
  pin_memory: true
# Lightning callbacks; each entry is instantiated by Hydra from its `_target_`.
callbacks:
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar

  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "valid_loss"  # name of the logged metric which determines when model is improving
    save_top_k: 1  # save k best models (determined by above metric)
    save_last: true  # additionally always save model from last epoch
    mode: "min"  # can be "max" or "min"
    verbose: false
    # `${logs_dir}` is not defined in this file (expected from the parent
    # config); `${now:...}` stamps each run with its own checkpoint directory.
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    # Quoted so the braces reach the callback as a format string instead of
    # being parsed as a YAML flow mapping.
    filename: '{epoch:02d}-{valid_loss:.3f}'

  model_summary:
    _target_: pytorch_lightning.callbacks.RichModelSummary
    max_depth: 2

  audio_samples_logger:
    _target_: main.module_diff_txt_emb.SampleLogger
    num_items: 3
    # Audio format mirrors the top-level values used for training.
    channels: ${channels}
    sampling_rate: ${sampling_rate}
    length: ${length}
    sampling_steps: [3, 5, 10, 25, 50, 100]
    embedding_scale: 15.0
    use_ema_model: true
    diffusion_sampler:
      _target_: audio_diffusion_pytorch.VSampler
    diffusion_schedule:
      _target_: audio_diffusion_pytorch.LinearSchedule
# Experiment loggers. The W&B project and entity are read from the
# WANDB_PROJECT / WANDB_ENTITY environment variables; no fallback default is
# given here, so both must be set before launching.
loggers:
  wandb:
    _target_: pytorch_lightning.loggers.wandb.WandbLogger
    project: ${oc.env:WANDB_PROJECT}
    entity: ${oc.env:WANDB_ENTITY}
    # offline: False  # set True to store all logs only locally
    job_type: "train"
    group: ""
    # `${logs_dir}` is not defined in this file; expected from the parent
    # config.
    save_dir: ${logs_dir}
# PyTorch Lightning Trainer arguments.
trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 0  # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0`
  precision: 32  # Precision used for tensors, default `32`
  accelerator: null  # `ddp` GPUs train individually and sync gradients, default `None`
  min_epochs: 0
  max_epochs: -1
  enable_model_summary: false
  # Nested under `trainer`, so this is distinct from the top-level
  # `log_every_n_steps` key.
  log_every_n_steps: 1  # Logs metrics every N batches
  check_val_every_n_epoch: null
  limit_val_batches: 20
  # `${log_every_n_steps}` is an absolute interpolation: it resolves to the
  # top-level value (1000), not to `trainer.log_every_n_steps` (1) above.
  val_check_interval: ${log_every_n_steps}