
Commit e64e8c4

add my study

1 parent 522f577

13 files changed: +519 -0

configs/audio/melspectrogram.yaml (+5)

extension: pcm
sampling_rate: 16000
n_mel: 80
frame_length: 20
frame_stride: 10
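
frame_length and frame_stride are given in milliseconds, so they have to be converted to samples before building the spectrogram transform. A minimal sketch of that conversion with torchaudio (how SpectrogramDataset actually consumes these values is an assumption, not code from this commit):

import torchaudio

sampling_rate = 16000
frame_length_ms, frame_stride_ms = 20, 10

# 20 ms at 16 kHz -> a 320-sample window; 10 ms -> a 160-sample hop.
win_length = int(sampling_rate * frame_length_ms / 1000)  # 320
hop_length = int(sampling_rate * frame_stride_ms / 1000)  # 160

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sampling_rate,
    n_fft=win_length,
    win_length=win_length,
    hop_length=hop_length,
    n_mels=80,  # n_mel from this config
)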

configs/eval.yaml (+3)

defaults:
  - audio: melspectrogram
  - eval: default

configs/eval/default.yaml (+16)

dataset_path: ''
audio_path: ''
label_path: D:/label/aihub_labels.csv
model_path: ''
save_transcripts_path: ''
print_interval: 10
num_vocabs: 2001
pad_id: 0
sos_id: 1
eos_id: 2
blank_id: 2000
batch_size: 4
num_workers: 4
cuda: True
seed: 22
mode: eval
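
eval.py registers an EvaluateConfig dataclass for this group in Hydra's ConfigStore, but its definition is not part of this diff. A plausible sketch, with field names and defaults mirroring the YAML above (the class body itself is an assumption):

from dataclasses import dataclass

@dataclass
class EvaluateConfig:
    dataset_path: str = ''
    audio_path: str = ''
    label_path: str = 'D:/label/aihub_labels.csv'
    model_path: str = ''
    save_transcripts_path: str = ''
    print_interval: int = 10
    num_vocabs: int = 2001
    pad_id: int = 0
    sos_id: int = 1
    eos_id: int = 2
    blank_id: int = 2000
    batch_size: int = 4
    num_workers: int = 4
    cuda: bool = True
    seed: int = 22
    mode: str = 'eval'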

configs/model/deepspeech2.yaml (+7)

architecture: deepspeech2
input_size: 80
hidden_size: 512
num_layers: 3
dropout: 0.3
bidirectional: True
rnn_type: gru

configs/model/joint_ctc_attention_las.yaml (+16)

architecture: las
input_size: 80
encoder_hidden_size: 256
decoder_hidden_size: 512
encoder_layers: 3
decoder_layers: 2
dropout: 0.3
bidirectional: True
rnn_type: lstm
teacher_forcing_ratio: 1.0
use_joint_ctc_attention: False
max_len: 120
attn_mechanism: location
smoothing: False
ctc_weight: 0.2
cross_entropy_weight: 0.8
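
When use_joint_ctc_attention is enabled, ctc_weight and cross_entropy_weight blend the CTC and attention objectives; note they sum to 1.0 here (0.2 + 0.8). A sketch of the usual hybrid loss (the repo's exact combination is an assumption):

import torch.nn as nn

ctc_criterion = nn.CTCLoss(blank=2000, zero_infinity=True)  # blank_id from the vocab config
ce_criterion = nn.CrossEntropyLoss(ignore_index=0)          # pad_id

def joint_loss(ctc_loss, ce_loss, ctc_weight=0.2, ce_weight=0.8):
    # Weighted hybrid CTC/attention objective.
    return ctc_weight * ctc_loss + ce_weight * ce_loss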

configs/model/las.yaml (+14)

architecture: las
input_size: 80
encoder_hidden_size: 256
decoder_hidden_size: 512
encoder_layers: 3
decoder_layers: 2
dropout: 0.3
bidirectional: True
rnn_type: lstm
teacher_forcing_ratio: 1.0
use_joint_ctc_attention: False
max_len: 120
attn_mechanism: location
smoothing: False

configs/train.yaml (+4)

defaults:
  - audio: melspectrogram
  - model: joint_ctc_attention_las
  - train: las_train
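
Since every option above is also registered in main.py's ConfigStore, the defaults can be swapped from the command line with standard Hydra group overrides, e.g.:

python main.py model=deepspeech2 train=deepspeech2_train
python main.py model=las train=las_train train.epochs=30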

configs/train/deepspeech2_train.yaml (+24)

# Dataset
dataset_path: D:/dataset/transcripts.txt
audio_path: E:/KsponSpeech
label_path: D:/label/aihub_labels.csv
model_save_path: deepspeech2_model.pt

# vocabulary
num_vocabs: 2001
pad_id: 0
sos_id: 1
eos_id: 2
blank_id: 2000

# trainer
batch_size: 4
num_workers: 4
epochs: 50
lr: 1e-06
print_interval: 10

# System
cuda: True
seed: 22
mode: train

configs/train/las_train.yaml (+24)

# Dataset
dataset_path: D:/dataset/transcripts.txt
audio_path: E:/KsponSpeech
label_path: D:/label/aihub_labels.csv
model_save_path: las_model.pt

# vocabulary
num_vocabs: 2001
pad_id: 0
sos_id: 1
eos_id: 2
blank_id: 2000

# trainer
batch_size: 4
num_workers: 4
epochs: 20
lr: 1e-06
print_interval: 10

# System
cuda: True
seed: 22
mode: train

eval.py (+69)

import torch
import numpy as np
import random
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig
from evaluator.evaluator import Evaluator
from model_builder import load_test_model
from data.data_loader import (
    SpectrogramDataset,
    AudioDataLoader,
)
from vocabulary import (
    load_label,
    load_dataset,
)

from data import MelSpectrogramConfig
from evaluator import EvaluateConfig


cs = ConfigStore.instance()
cs.store(group="audio", name="melspectrogram", node=MelSpectrogramConfig, package="audio")
cs.store(group="eval", name="default", node=EvaluateConfig, package="eval")


@hydra.main(config_path='configs', config_name='eval')
def main(config: DictConfig) -> None:
    print(OmegaConf.to_yaml(config))

    torch.manual_seed(config.eval.seed)
    torch.cuda.manual_seed_all(config.eval.seed)
    np.random.seed(config.eval.seed)
    random.seed(config.eval.seed)

    use_cuda = config.eval.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    char2id, id2char = load_label(config.eval.label_path, config.eval.blank_id)
    audio_paths, transcripts, _, _ = load_dataset(config.eval.dataset_path, config.eval.mode)

    test_dataset = SpectrogramDataset(
        config.eval.audio_path,
        audio_paths,
        transcripts,
        config.audio.sampling_rate,
        config.audio.n_mel,
        config.audio.frame_length,
        config.audio.frame_stride,
        config.audio.extension,
        config.eval.sos_id,  # eval.yaml composes only the audio and eval groups
        config.eval.eos_id,
    )
    test_loader = AudioDataLoader(
        test_dataset,
        batch_size=config.eval.batch_size,
        num_workers=config.eval.num_workers,
    )

    model = load_test_model(config, device)

    print('Start Inference !!!')

    evaluator = Evaluator(config, device, test_loader, id2char)
    evaluator.evaluate(model)


if __name__ == "__main__":
    main()
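
The path fields in configs/eval/default.yaml are empty by default, so a typical run supplies them as Hydra overrides (the concrete paths below are placeholders taken from the train configs):

python eval.py eval.dataset_path=D:/dataset/transcripts.txt eval.audio_path=E:/KsponSpeech eval.model_path=las_model.pt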

main.py (+109)

import torch
import torch.optim as optim
import numpy as np
import random
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig
from trainer.trainer import train
from model_builder import build_model

from data.data_loader import (
    SpectrogramDataset,
    BucketingSampler,
    AudioDataLoader,
)
from vocabulary import (
    load_label,
    load_dataset,
)

from data import MelSpectrogramConfig
from models.las import (
    ListenAttendSpellConfig,
    JointCTCAttentionLASConfig,
)
from models.deepspeech2 import DeepSpeech2Config
from trainer import (
    ListenAttendSpellTrainConfig,
    DeepSpeech2TrainConfig,
)


cs = ConfigStore.instance()
cs.store(group="audio", name="melspectrogram", node=MelSpectrogramConfig, package="audio")
cs.store(group="model", name="las", node=ListenAttendSpellConfig, package="model")
cs.store(group="model", name="joint_ctc_attention_las", node=JointCTCAttentionLASConfig, package="model")
cs.store(group="model", name="deepspeech2", node=DeepSpeech2Config, package="model")
cs.store(group="train", name="las_train", node=ListenAttendSpellTrainConfig, package="train")
cs.store(group="train", name="deepspeech2_train", node=DeepSpeech2TrainConfig, package="train")


@hydra.main(config_path='configs', config_name='train')
def main(config: DictConfig) -> None:
    print(OmegaConf.to_yaml(config))

    torch.manual_seed(config.train.seed)
    torch.cuda.manual_seed_all(config.train.seed)
    np.random.seed(config.train.seed)
    random.seed(config.train.seed)

    use_cuda = config.train.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    char2id, id2char = load_label(config.train.label_path, config.train.blank_id)
    train_audio_paths, train_transcripts, valid_audio_paths, valid_transcripts = load_dataset(config.train.dataset_path, config.train.mode)

    train_dataset = SpectrogramDataset(
        config.train.audio_path,
        train_audio_paths,
        train_transcripts,
        config.audio.sampling_rate,
        config.audio.n_mel,
        config.audio.frame_length,
        config.audio.frame_stride,
        config.audio.extension,
        config.train.sos_id,
        config.train.eos_id,
    )

    train_sampler = BucketingSampler(train_dataset, batch_size=config.train.batch_size)
    train_loader = AudioDataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.train.num_workers,
    )

    valid_dataset = SpectrogramDataset(
        config.train.audio_path,
        valid_audio_paths,
        valid_transcripts,
        config.audio.sampling_rate,
        config.audio.n_mel,
        config.audio.frame_length,
        config.audio.frame_stride,
        config.audio.extension,
        config.train.sos_id,
        config.train.eos_id,
    )
    valid_sampler = BucketingSampler(valid_dataset, batch_size=config.train.batch_size)
    valid_loader = AudioDataLoader(
        valid_dataset,
        batch_sampler=valid_sampler,
        num_workers=config.train.num_workers,
    )

    model = build_model(config, device)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.train.lr)

    print('Start Train !!!')
    for epoch in range(config.train.epochs):
        train(config, model, device, train_loader, valid_loader, train_sampler, optimizer, epoch, id2char)

    # Save the full module so load_test_model's torch.load() in model_builder.py
    # gets an nn.Module back rather than a bare state_dict.
    torch.save(model, config.train.model_save_path)


if __name__ == "__main__":
    main()

model_builder.py (+74)

from models.las.encoder import Encoder
from models.las.decoder import Decoder
from models.las.model import ListenAttendSpell
from models.deepspeech2.model import DeepSpeech2
from omegaconf import DictConfig

import torch
import torch.nn as nn


def build_model(config: DictConfig, device: torch.device):
    if config.model.architecture == 'las':
        return build_las_model(config, device)
    elif config.model.architecture == 'deepspeech2':
        return build_ds2_model(config)
    raise ValueError(f'unsupported architecture: {config.model.architecture}')


def load_test_model(config: DictConfig, device: torch.device) -> nn.Module:
    model = torch.load(config.eval.model_path, map_location=lambda storage, loc: storage).to(device)

    model.encoder.device = device
    model.decoder.device = device

    return model


def build_encoder(config: DictConfig) -> Encoder:
    return Encoder(
        config.train.num_vocabs,
        config.model.input_size,
        config.model.encoder_hidden_size,
        config.model.encoder_layers,
        config.model.dropout,
        config.model.bidirectional,
        config.model.rnn_type,
        config.model.use_joint_ctc_attention,
    )


def build_decoder(config: DictConfig, device: torch.device) -> Decoder:
    return Decoder(
        device,
        config.train.num_vocabs,
        config.model.decoder_hidden_size,
        config.model.decoder_hidden_size,
        config.model.decoder_layers,
        config.model.max_len,
        config.model.dropout,
        config.model.rnn_type,
        config.model.attn_mechanism,
        config.model.smoothing,
        config.train.sos_id,
        config.train.eos_id,
    )


def build_las_model(config: DictConfig, device: torch.device) -> ListenAttendSpell:
    encoder = build_encoder(config)
    decoder = build_decoder(config, device)

    return ListenAttendSpell(encoder, decoder)


def build_ds2_model(config: DictConfig) -> DeepSpeech2:
    return DeepSpeech2(
        config.train.num_vocabs,
        config.model.input_size,
        config.model.hidden_size,
        config.model.num_layers,
        config.model.dropout,
        config.model.bidirectional,
        config.model.rnn_type,
    )
