-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
017d398
commit abd1122
Showing
15 changed files
with
557 additions
and
407 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,5 @@ __pycache__/ | |
.idea/ | ||
dataset/ | ||
models/ | ||
log/ | ||
infer_audio.wav |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import argparse | ||
import functools | ||
|
||
import numpy as np | ||
import torch | ||
from sklearn.metrics import confusion_matrix | ||
from torch.utils.data import DataLoader | ||
|
||
from utils.ecapa_tdnn import EcapaTdnn | ||
from utils.reader import CustomDataset, collate_fn | ||
from utils.utility import add_arguments, print_arguments, plot_confusion_matrix | ||
|
||
parser = argparse.ArgumentParser(description=__doc__) | ||
add_arg = functools.partial(add_arguments, argparser=parser) | ||
add_arg('batch_size', int, 32, '训练的批量大小') | ||
add_arg('num_workers', int, 4, '读取数据的线程数量') | ||
add_arg('num_classes', int, 10, '分类的类别数量') | ||
add_arg('learning_rate', float, 1e-3, '初始学习率的大小') | ||
add_arg('test_list_path', str, 'dataset/test_list.txt', '测试数据的数据列表路径') | ||
add_arg('label_list_path', str, 'dataset/label_list.txt', '标签列表路径') | ||
add_arg('model_path', str, 'models/model.pth', '模型保存的路径') | ||
args = parser.parse_args() | ||
|
||
|
||
def evaluate(): | ||
test_dataset = CustomDataset(args.test_list_path, model='eval') | ||
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, num_workers=args.num_workers) | ||
# 获取分类标签 | ||
with open(args.label_list_path, 'r', encoding='utf-8') as f: | ||
lines = f.readlines() | ||
class_labels = [l.replace('\n', '') for l in lines] | ||
# 获取模型 | ||
device = torch.device("cuda") | ||
model = EcapaTdnn(num_classes=args.num_classes) | ||
model.to(device) | ||
model.load_state_dict(torch.load(args.model_path)) | ||
model.eval() | ||
|
||
accuracies, preds, labels = [], [], [] | ||
for batch_id, (spec_mag, label) in enumerate(test_loader): | ||
spec_mag = spec_mag.to(device) | ||
label = label.numpy() | ||
output = model(spec_mag) | ||
output = output.data.cpu().numpy() | ||
pred = np.argmax(output, axis=1) | ||
preds.extend(pred.tolist()) | ||
labels.extend(label.tolist()) | ||
acc = np.mean((pred == label).astype(int)) | ||
accuracies.append(acc.item()) | ||
acc = float(sum(accuracies) / len(accuracies)) | ||
cm = confusion_matrix(labels, preds) | ||
FP = cm.sum(axis=0) - np.diag(cm) | ||
FN = cm.sum(axis=1) - np.diag(cm) | ||
TP = np.diag(cm) | ||
TN = cm.sum() - (FP + FN + TP) | ||
# 精确率 | ||
precision = TP / (TP + FP + 1e-6) | ||
# 召回率 | ||
recall = TP / (TP + FN + 1e-6) | ||
print('分类准确率: {:.4f}, 平均精确率: {:.4f}, 平均召回率: {:.4f}'.format(acc, np.mean(precision), np.mean(recall))) | ||
plot_confusion_matrix(cm=cm, save_path='log/混淆矩阵_eval.png', class_labels=class_labels) | ||
|
||
|
||
if __name__ == '__main__': | ||
print_arguments(args) | ||
evaluate() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,47 @@ | ||
import argparse | ||
import functools | ||
|
||
import librosa | ||
import numpy as np | ||
import torch | ||
|
||
# 加载模型 | ||
model_path = 'models/resnet34.pth' | ||
from utils.ecapa_tdnn import EcapaTdnn | ||
from utils.reader import load_audio | ||
from utils.utility import add_arguments | ||
|
||
parser = argparse.ArgumentParser(description=__doc__) | ||
add_arg = functools.partial(add_arguments, argparser=parser) | ||
add_arg('audio_path', str, 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav', '图片路径') | ||
add_arg('num_classes', int, 10, '分类的类别数量') | ||
add_arg('label_list_path', str, 'dataset/label_list.txt', '标签列表路径') | ||
add_arg('model_path', str, 'models/model.pth', '模型保存的路径') | ||
args = parser.parse_args() | ||
|
||
|
||
# 获取分类标签 | ||
with open(args.label_list_path, 'r', encoding='utf-8') as f: | ||
lines = f.readlines() | ||
class_labels = [l.replace('\n', '') for l in lines] | ||
# 获取模型 | ||
device = torch.device("cuda") | ||
model = torch.jit.load(model_path) | ||
model = EcapaTdnn(num_classes=args.num_classes) | ||
model.to(device) | ||
model.load_state_dict(torch.load(args.model_path)) | ||
model.eval() | ||
|
||
|
||
# 读取音频数据 | ||
def load_data(data_path): | ||
# 读取音频 | ||
wav, sr = librosa.load(data_path, sr=16000) | ||
spec_mag = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256).astype(np.float32) | ||
mean = np.mean(spec_mag, 0, keepdims=True) | ||
std = np.std(spec_mag, 0, keepdims=True) | ||
spec_mag = (spec_mag - mean) / (std + 1e-5) | ||
spec_mag = spec_mag[np.newaxis, np.newaxis, :] | ||
spec_mag = spec_mag.astype('float32') | ||
return spec_mag | ||
|
||
|
||
def infer(audio_path): | ||
data = load_data(audio_path) | ||
def infer(): | ||
data = load_audio(args.audio_path, mode='infer') | ||
data = data[np.newaxis, :] | ||
data = torch.tensor(data, dtype=torch.float32, device=device) | ||
# 执行预测 | ||
output = model(data) | ||
result = torch.nn.functional.softmax(output) | ||
result = torch.nn.functional.softmax(output, dim=-1) | ||
result = result.data.cpu().numpy() | ||
print(result) | ||
# 显示图片并输出结果最大的label | ||
lab = np.argsort(result)[0][-1] | ||
return lab | ||
print(f'音频:{args.audio_path} 的预测结果标签为:{class_labels[lab]}') | ||
|
||
|
||
if __name__ == '__main__': | ||
# 要预测的音频文件 | ||
path = 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav' | ||
label = infer(path) | ||
print('音频:%s 的预测结果标签为:%d' % (path, label)) | ||
infer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.