Skip to content

Commit abc7bb3

Browse files
committed
Update hydra
1 parent 3e2914f commit abc7bb3

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

eval.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2021 Sangchun Ha
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
import torch
13+
import numpy as np
14+
import random
15+
import hydra
16+
17+
from hydra.core.config_store import ConfigStore
18+
from omegaconf import OmegaConf, DictConfig
19+
from evaluator.evaluator import Evaluator
20+
from model_builder import load_test_model
21+
from data import MelSpectrogramConfig
22+
from evaluator import EvaluateConfig
23+
from data.data_loader import (
24+
SpectrogramDataset,
25+
AudioDataLoader,
26+
)
27+
from vocabulary import (
28+
load_label,
29+
load_dataset,
30+
)
31+
32+
33+
# Register the structured-config nodes with Hydra's ConfigStore.
# Every audio feature option is backed by the same MelSpectrogramConfig node.
cs = ConfigStore.instance()
for _audio_name in ("melspectrogram", "filterbank", "mfcc", "spectrogram"):
    cs.store(group="audio", name=_audio_name, node=MelSpectrogramConfig, package="audio")
cs.store(group="eval", name="default", node=EvaluateConfig, package="eval")
39+
40+
41+
@hydra.main(config_path='configs', config_name='eval')
def main(config: DictConfig) -> None:
    """Evaluate a trained model on the test split.

    Prints the resolved configuration, seeds the torch / numpy / random
    generators from ``config.eval.seed``, builds the test dataset and loader,
    loads the model with ``load_test_model`` and runs ``Evaluator.evaluate``.

    Args:
        config: Hydra-composed configuration. The ``eval`` and ``audio``
            groups are read here, plus ``train.sos_id`` / ``train.eos_id``.
    """
    print(OmegaConf.to_yaml(config))

    # Seed every random number generator used by the pipeline.
    seed = config.eval.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Only use the GPU when it is both requested and available.
    use_cuda = config.eval.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Only the id -> character table is needed for evaluation.
    _, id2char = load_label(config.eval.label_path, config.eval.blank_id)
    audio_paths, transcripts, _, _ = load_dataset(config.eval.dataset_path, config.eval.mode)

    audio = config.audio
    # Same feature-size selection as the training script: MFCC has its own setting.
    feature_size = audio.n_mfcc if audio.feature_extraction == 'mfcc' else audio.n_mel

    # The positional layout mirrors the SpectrogramDataset calls in the training
    # script. The previous call omitted the six feature/augmentation arguments,
    # so sos_id and eos_id were passed in the wrong positional slots.
    test_dataset = SpectrogramDataset(
        config.eval.audio_path,
        audio_paths,
        transcripts,
        audio.sampling_rate,
        feature_size,
        audio.frame_length,
        audio.frame_stride,
        audio.extension,
        audio.feature_extraction,
        audio.normalize,
        False,  # spec_augment: never augment the data being evaluated
        audio.freq_mask_parameter,
        audio.num_time_mask,
        audio.num_freq_mask,
        # NOTE(review): only the "audio" and "eval" groups are registered in this
        # script; confirm that configs/eval.yaml actually provides a "train" section.
        config.train.sos_id,
        config.train.eos_id,
    )
    test_loader = AudioDataLoader(
        test_dataset,
        batch_size=config.eval.batch_size,
        num_workers=config.eval.num_workers,
    )

    model = load_test_model(config, device)

    print('Start Test !!!')

    evaluator = Evaluator(config, device, test_loader, id2char)
    evaluator.evaluate(model)
80+
81+
82+
# Script entry point: the Hydra-decorated main() composes the config itself.
if __name__ == "__main__":
    main()

main.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2021 Sangchun Ha
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
import torch
13+
import torch.optim as optim
14+
import numpy as np
15+
import random
16+
import os
17+
import hydra
18+
import warnings
19+
20+
from hydra.core.config_store import ConfigStore
21+
from omegaconf import OmegaConf, DictConfig
22+
from trainer.trainer import train
23+
from model_builder import build_model
24+
from data.data_loader import (
25+
SpectrogramDataset,
26+
BucketingSampler,
27+
AudioDataLoader,
28+
)
29+
from vocabulary import (
30+
load_label,
31+
load_dataset,
32+
)
33+
from data import MelSpectrogramConfig
34+
from models.las import (
35+
ListenAttendSpellConfig,
36+
JointCTCAttentionLASConfig,
37+
)
38+
from models.deepspeech2 import DeepSpeech2Config
39+
from trainer import (
40+
ListenAttendSpellTrainConfig,
41+
DeepSpeech2TrainConfig,
42+
)
43+
44+
45+
# Register the structured-config nodes with Hydra's ConfigStore.
# Each entry is (group, name, node); the package always equals the group.
cs = ConfigStore.instance()
_CONFIG_NODES = (
    ("audio", "filterbank", MelSpectrogramConfig),
    ("audio", "mfcc", MelSpectrogramConfig),
    ("audio", "spectrogram", MelSpectrogramConfig),
    ("audio", "melspectrogram", MelSpectrogramConfig),
    ("model", "las", ListenAttendSpellConfig),
    ("model", "joint_ctc_attention_las", JointCTCAttentionLASConfig),
    ("model", "deepspeech2", DeepSpeech2Config),
    ("train", "las_train", ListenAttendSpellTrainConfig),
    ("train", "deepspeech2_train", DeepSpeech2TrainConfig),
)
for _group, _name, _node in _CONFIG_NODES:
    cs.store(group=_group, name=_name, node=_node, package=_group)
55+
56+
57+
def _build_dataset(config: DictConfig, audio_paths, transcripts):
    """Build a ``SpectrogramDataset`` for one split from the shared settings.

    Args:
        config: Hydra-composed configuration (``audio`` and ``train`` groups).
        audio_paths: Audio paths of the split, as returned by ``load_dataset``.
        transcripts: Transcripts of the split, as returned by ``load_dataset``.

    Returns:
        The ``SpectrogramDataset`` built for the given split.
    """
    audio = config.audio
    # MFCC features have their own size setting; every other feature type uses n_mel.
    feature_size = audio.n_mfcc if audio.feature_extraction == 'mfcc' else audio.n_mel
    return SpectrogramDataset(
        config.train.audio_path,
        audio_paths,
        transcripts,
        audio.sampling_rate,
        feature_size,
        audio.frame_length,
        audio.frame_stride,
        audio.extension,
        audio.feature_extraction,
        audio.normalize,
        audio.spec_augment,
        audio.freq_mask_parameter,
        audio.num_time_mask,
        audio.num_freq_mask,
        config.train.sos_id,
        config.train.eos_id,
    )


@hydra.main(config_path='configs', config_name='train')
def main(config: DictConfig) -> None:
    """Train a speech-recognition model and checkpoint it periodically.

    Prints the resolved configuration, seeds the torch / numpy / random
    generators from ``config.train.seed``, builds bucketed train and validation
    loaders, builds the model and an Adam optimizer, then calls ``train`` once
    per epoch. The whole model is saved with ``torch.save`` on every
    even-numbered epoch.

    Args:
        config: Hydra-composed configuration (``audio`` and ``train`` groups
            are read here; the full config is forwarded to ``build_model``
            and ``train``).
    """
    warnings.filterwarnings('ignore')
    print(OmegaConf.to_yaml(config))

    # Seed every random number generator used by the pipeline.
    seed = config.train.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Only use the GPU when it is both requested and available.
    use_cuda = config.train.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Only the id -> character table is used below.
    _, id2char = load_label(config.train.label_path, config.train.blank_id)
    train_audio_paths, train_transcripts, valid_audio_paths, valid_transcripts = load_dataset(
        config.train.dataset_path,
        config.train.mode,
    )

    train_dataset = _build_dataset(config, train_audio_paths, train_transcripts)
    train_sampler = BucketingSampler(train_dataset, batch_size=config.train.batch_size)
    train_loader = AudioDataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.train.num_workers,
    )

    # NOTE(review): the validation split receives the same spec_augment setting
    # as the training split; confirm that augmenting validation data is intended.
    valid_dataset = _build_dataset(config, valid_audio_paths, valid_transcripts)
    valid_sampler = BucketingSampler(valid_dataset, batch_size=config.train.batch_size)
    valid_loader = AudioDataLoader(
        valid_dataset,
        batch_sampler=valid_sampler,
        num_workers=config.train.num_workers,
    )

    model = build_model(config, device)

    optimizer = optim.Adam(model.parameters(), lr=config.train.lr)

    print('Start Train !!!')
    for epoch in range(config.train.epochs):
        # NOTE(review): epoch is passed twice, exactly as in the original call;
        # confirm against the signature of trainer.trainer.train.
        train(config, model, device, train_loader, valid_loader, train_sampler, optimizer, epoch, id2char, epoch)

        # Checkpoint every second epoch (0, 2, 4, ...), relative to the current
        # working directory. NOTE(review): an odd final epoch index is never saved.
        if epoch % 2 == 0:
            checkpoint_name = f"{config.train.model_save_path}{epoch}.pt"
            torch.save(model, os.path.join(os.getcwd(), checkpoint_name))
134+
135+
136+
# Script entry point: the Hydra-decorated main() composes the config itself.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)