Skip to content

Commit

Permalink
更新模型和修改预处理方式
Browse files Browse the repository at this point in the history
  • Loading branch information
yeyupiaoling committed Apr 26, 2022
1 parent 017d398 commit abd1122
Show file tree
Hide file tree
Showing 15 changed files with 557 additions and 407 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ __pycache__/
.idea/
dataset/
models/
log/
infer_audio.wav
212 changes: 109 additions & 103 deletions README.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion create_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ def get_data_list(audio_path, list_path):

f_train = open(os.path.join(list_path, 'train_list.txt'), 'w')
f_test = open(os.path.join(list_path, 'test_list.txt'), 'w')
f_label = open(os.path.join(list_path, 'label_list.txt'), 'w')

for i in range(len(audios)):
f_label.write(f'{audios[i]}\n')
sounds = os.listdir(os.path.join(audio_path, audios[i]))
for sound in sounds:
if '.wav' not in sound:continue
Expand All @@ -24,7 +26,7 @@ def get_data_list(audio_path, list_path):
f_train.write('%s\t%d\n' % (sound_path, i))
sound_sum += 1
print("Audio:%d/%d" % (i + 1, len(audios)))

f_label.close()
f_test.close()
f_train.close()

Expand Down
66 changes: 66 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import argparse
import functools

import numpy as np
import torch
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader

from utils.ecapa_tdnn import EcapaTdnn
from utils.reader import CustomDataset, collate_fn
from utils.utility import add_arguments, print_arguments, plot_confusion_matrix

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 32, '训练的批量大小')
add_arg('num_workers', int, 4, '读取数据的线程数量')
add_arg('num_classes', int, 10, '分类的类别数量')
add_arg('learning_rate', float, 1e-3, '初始学习率的大小')
add_arg('test_list_path', str, 'dataset/test_list.txt', '测试数据的数据列表路径')
add_arg('label_list_path', str, 'dataset/label_list.txt', '标签列表路径')
add_arg('model_path', str, 'models/model.pth', '模型保存的路径')
args = parser.parse_args()


def evaluate():
test_dataset = CustomDataset(args.test_list_path, model='eval')
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=args.num_workers)
# 获取分类标签
with open(args.label_list_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
class_labels = [l.replace('\n', '') for l in lines]
# 获取模型
device = torch.device("cuda")
model = EcapaTdnn(num_classes=args.num_classes)
model.to(device)
model.load_state_dict(torch.load(args.model_path))
model.eval()

accuracies, preds, labels = [], [], []
for batch_id, (spec_mag, label) in enumerate(test_loader):
spec_mag = spec_mag.to(device)
label = label.numpy()
output = model(spec_mag)
output = output.data.cpu().numpy()
pred = np.argmax(output, axis=1)
preds.extend(pred.tolist())
labels.extend(label.tolist())
acc = np.mean((pred == label).astype(int))
accuracies.append(acc.item())
acc = float(sum(accuracies) / len(accuracies))
cm = confusion_matrix(labels, preds)
FP = cm.sum(axis=0) - np.diag(cm)
FN = cm.sum(axis=1) - np.diag(cm)
TP = np.diag(cm)
TN = cm.sum() - (FP + FN + TP)
# 精确率
precision = TP / (TP + FP + 1e-6)
# 召回率
recall = TP / (TP + FN + 1e-6)
print('分类准确率: {:.4f}, 平均精确率: {:.4f}, 平均召回率: {:.4f}'.format(acc, np.mean(precision), np.mean(recall)))
plot_confusion_matrix(cm=cm, save_path='log/混淆矩阵_eval.png', class_labels=class_labels)


if __name__ == '__main__':
print_arguments(args)
evaluate()
Binary file added images/image1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
54 changes: 29 additions & 25 deletions infer.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,47 @@
import argparse
import functools

import librosa
import numpy as np
import torch

# 加载模型
model_path = 'models/resnet34.pth'
from utils.ecapa_tdnn import EcapaTdnn
from utils.reader import load_audio
from utils.utility import add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('audio_path', str, 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav', '图片路径')
add_arg('num_classes', int, 10, '分类的类别数量')
add_arg('label_list_path', str, 'dataset/label_list.txt', '标签列表路径')
add_arg('model_path', str, 'models/model.pth', '模型保存的路径')
args = parser.parse_args()


# 获取分类标签
with open(args.label_list_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
class_labels = [l.replace('\n', '') for l in lines]
# 获取模型
device = torch.device("cuda")
model = torch.jit.load(model_path)
model = EcapaTdnn(num_classes=args.num_classes)
model.to(device)
model.load_state_dict(torch.load(args.model_path))
model.eval()


# 读取音频数据
def load_data(data_path):
# 读取音频
wav, sr = librosa.load(data_path, sr=16000)
spec_mag = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256).astype(np.float32)
mean = np.mean(spec_mag, 0, keepdims=True)
std = np.std(spec_mag, 0, keepdims=True)
spec_mag = (spec_mag - mean) / (std + 1e-5)
spec_mag = spec_mag[np.newaxis, np.newaxis, :]
spec_mag = spec_mag.astype('float32')
return spec_mag


def infer(audio_path):
data = load_data(audio_path)
def infer():
data = load_audio(args.audio_path, mode='infer')
data = data[np.newaxis, :]
data = torch.tensor(data, dtype=torch.float32, device=device)
# 执行预测
output = model(data)
result = torch.nn.functional.softmax(output)
result = torch.nn.functional.softmax(output, dim=-1)
result = result.data.cpu().numpy()
print(result)
# 显示图片并输出结果最大的label
lab = np.argsort(result)[0][-1]
return lab
print(f'音频:{args.audio_path} 的预测结果标签为:{class_labels[lab]}')


if __name__ == '__main__':
# 要预测的音频文件
path = 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav'
label = infer(path)
print('音频:%s 的预测结果标签为:%d' % (path, label))
infer()
60 changes: 31 additions & 29 deletions infer_record.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
import argparse
import functools
import wave
import librosa

import numpy as np
import pyaudio
import torch

# 加载模型
model_path = 'models/resnet34.pth'
from utils.ecapa_tdnn import EcapaTdnn
from utils.reader import load_audio
from utils.utility import add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('num_classes', int, 10, '分类的类别数量')
add_arg('label_list_path', str, 'dataset/label_list.txt', '标签列表路径')
add_arg('model_path', str, 'models/model.pth', '模型保存的路径')
args = parser.parse_args()


# 获取分类标签
with open(args.label_list_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
class_labels = [l.replace('\n', '') for l in lines]
# 获取模型
device = torch.device("cuda")
model = torch.jit.load(model_path)
model = EcapaTdnn(num_classes=args.num_classes)
model.to(device)
model.load_state_dict(torch.load(args.model_path))
model.eval()

# 录音参数
Expand All @@ -28,19 +46,6 @@
frames_per_buffer=CHUNK)


# 读取音频数据
def load_data(data_path):
# 读取音频
wav, sr = librosa.load(data_path, sr=16000)
spec_mag = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256).astype(np.float32)
mean = np.mean(spec_mag, 0, keepdims=True)
std = np.std(spec_mag, 0, keepdims=True)
spec_mag = (spec_mag - mean) / (std + 1e-5)
spec_mag = spec_mag[np.newaxis, np.newaxis, :]
spec_mag = spec_mag.astype('float32')
return spec_mag


# 获取录音数据
def record_audio():
print("开始录音......")
Expand All @@ -63,29 +68,26 @@ def record_audio():

# 预测
def infer(audio_path):
data = load_data(audio_path)
data = load_audio(audio_path, mode='infer')
data = data[np.newaxis, :]
data = torch.tensor(data, dtype=torch.float32, device=device)
# 执行预测
output = model(data)
result = torch.nn.functional.softmax(output)
result = torch.nn.functional.softmax(output, dim=-1)
result = result.data.cpu().numpy()
print(result)
# 显示图片并输出结果最大的label
lab = np.argsort(result)[0][-1]
return lab
return class_labels[lab]


if __name__ == '__main__':
try:
while True:
try:
# 加载数据
audio_path = record_audio()
# 获取预测结果
label = infer(audio_path)
print('预测的标签为:%d' % label)
except:
pass
# 加载数据
audio_path = record_audio()
# 获取预测结果
label = infer(audio_path)
print(f'预测的标签为:{label}')
except Exception as e:
print(e)
stream.stop_stream()
Expand Down
41 changes: 0 additions & 41 deletions reader.py

This file was deleted.

Loading

0 comments on commit abd1122

Please sign in to comment.