Skip to content

Commit

Permalink
修改默认参数和修复预测错误
Browse files Browse the repository at this point in the history
  • Loading branch information
yeyupiaoling committed Aug 30, 2023
1 parent 5761547 commit 03b8fad
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 21 deletions.
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@

# 模型测试表

| 模型 | Params(M) | 预处理方法 | 数据集 | 类别数量 | 准确率 | 获取模型 |
|:------------:|:---------:|:-----:|:------------:|:----:|:-------:|:--------:|
| ResNetSE | 7.8 | Fbank | UrbanSound8K | 10 | 0.98863 | 加入知识星球获取 |
| CAMPPlus | 7.1 | Fbank | UrbanSound8K | 10 | 0.97727 | 加入知识星球获取 |
| ERes2Net | 6.6 | Fbank | UrbanSound8K | 10 | 0.96590 | 加入知识星球获取 |
| PANNS(CNN10) | 5.2 | Fbank | UrbanSound8K | 10 | 0.96590 | 加入知识星球获取 |
| Res2Net | 5.0 | Fbank | UrbanSound8K | 10 | 0.94318 | 加入知识星球获取 |
| TDNN | 2.6 | Fbank | UrbanSound8K | 10 | 0.92045 | 加入知识星球获取 |
| EcapaTdnn | 6.1 | Fbank | UrbanSound8K | 10 | 0.91876 | 加入知识星球获取 |
| 模型 | Params(M) | 预处理方法 | 数据集 | 类别数量 | 准确率 | 获取模型 |
|:------------:|:---------:|:-----:|:-----------------:|:----:|:-------:|:--------:|
| ResNetSE | 7.8 | Fbank | UrbanSound8K | 10 | 0.98863 | 加入知识星球获取 |
| CAMPPlus | 7.1 | Fbank | UrbanSound8K | 10 | 0.97727 | 加入知识星球获取 |
| ERes2Net | 6.6 | Fbank | UrbanSound8K | 10 | 0.96590 | 加入知识星球获取 |
| PANNS(CNN10) | 5.2 | Fbank | UrbanSound8K | 10 | 0.96590 | 加入知识星球获取 |
| Res2Net | 5.0 | Fbank | UrbanSound8K | 10 | 0.94318 | 加入知识星球获取 |
| TDNN | 2.6 | Fbank | UrbanSound8K | 10 | 0.92045 | 加入知识星球获取 |
| EcapaTdnn | 6.1 | Fbank | UrbanSound8K | 10 | 0.91876 | 加入知识星球获取 |
| CAMPPlus | 6.1 | Fbank | CN-Celeb和VoxCeleb | 2 | 0.99320 | 加入知识星球获取 |
| ResNetSE | 9.1 | Fbank | CN-Celeb和VoxCeleb | 2 | | 加入知识星球获取 |

## 安装环境

Expand Down
1 change: 1 addition & 0 deletions configs/eres2net.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ model_conf:
num_class: null
# 所使用的池化层,支持ASP、SAP、TSP、TAP、TSTP
pooling_type: 'TSTP'
embd_dim: 192

train_conf:
# 是否开启自动混合精度
Expand Down
4 changes: 2 additions & 2 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/ecapa_tdnn.yml', "配置文件")
add_arg('configs', str, 'configs/cam++.yml', "配置文件")
add_arg("use_gpu", bool, True, "是否使用GPU评估模型")
add_arg('save_matrix_path', str, 'output/images/', "保存混合矩阵的路径")
add_arg('resume_model', str, 'models/EcapaTdnn_Fbank/best_model/', "模型的路径")
add_arg('resume_model', str, 'models/CAMPPlus_Fbank/best_model/', "模型的路径")
args = parser.parse_args()
print_arguments(args=args)

Expand Down
4 changes: 2 additions & 2 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/ecapa_tdnn.yml', '配置文件')
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg('use_gpu', bool, True, '是否使用GPU预测')
add_arg('audio_path', str, 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav', '音频路径')
add_arg('model_path', str, 'models/EcapaTdnn_Fbank/best_model/', '导出的预测模型文件路径')
add_arg('model_path', str, 'models/CAMPPlus_Fbank/best_model/', '导出的预测模型文件路径')
args = parser.parse_args()
print_arguments(args=args)

Expand Down
4 changes: 2 additions & 2 deletions infer_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/ecapa_tdnn.yml', '配置文件')
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg('use_gpu', bool, True, '是否使用GPU预测')
add_arg('record_seconds', int, 3, '录音长度')
add_arg('model_path', str, 'models/EcapaTdnn_Fbank/best_model/', '导出的预测模型文件路径')
add_arg('model_path', str, 'models/CAMPPlus_Fbank/best_model/', '导出的预测模型文件路径')
args = parser.parse_args()
print_arguments(args=args)

Expand Down
2 changes: 1 addition & 1 deletion macls/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.1"
__version__ = "0.4.2"
# 项目支持的模型
SUPPORT_MODEL = ['EcapaTdnn', 'PANNS_CNN6', 'PANNS_CNN10', 'PANNS_CNN14', 'Res2Net', 'ResNetSE', 'TDNN', 'ERes2Net',
'CAMPPlus']
3 changes: 1 addition & 2 deletions macls/models/campplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,5 @@ def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = self.head(x)
x = self.xvector(x)

out = self.fc(x)
x = self.fc(x)
return x
4 changes: 2 additions & 2 deletions macls/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def __init__(self,
elif self.configs.use_model == 'TDNN':
self.predictor = TDNN(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
elif self.configs.use_model == 'ERes2Net':
self.model = ERes2Net(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
self.predictor = ERes2Net(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
elif self.configs.use_model == 'CAMPPlus':
self.model = CAMPPlus(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
self.predictor = CAMPPlus(input_size=self._audio_featurizer.feature_dim, **self.configs.model_conf)
else:
raise Exception(f'{self.configs.use_model} 模型不存在!')
self.predictor.to(self.device)
Expand Down
2 changes: 2 additions & 0 deletions macls/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __setup_dataloader(self, is_train=False):
mode='eval')
self.test_loader = DataLoader(dataset=self.test_dataset,
collate_fn=collate_fn,
shuffle=True,
batch_size=self.configs.dataset_conf.eval_conf.batch_size,
num_workers=self.configs.dataset_conf.dataLoader.num_workers)

Expand Down Expand Up @@ -325,6 +326,7 @@ def __train_epoch(self, epoch_id, local_rank, writer, nranks=0):

# 多卡训练只使用一个进程打印
if batch_id % self.configs.train_conf.log_interval == 0 and local_rank == 0:
batch_id = batch_id + 1
# 计算每秒训练数据量
train_speed = self.configs.dataset_conf.dataLoader.batch_size / (sum(train_times) / len(train_times) / 1000)
# 计算剩余时间
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('configs', str, 'configs/ecapa_tdnn.yml', '配置文件')
add_arg('configs', str, 'configs/cam++.yml', '配置文件')
add_arg("local_rank", int, 0, '多卡训练需要的参数')
add_arg("use_gpu", bool, True, '是否使用GPU训练')
add_arg('save_model_path', str, 'models/', '模型保存的路径')
Expand Down

0 comments on commit 03b8fad

Please sign in to comment.