@@ -0,0 +1,7 @@
logs/**
logs_search/**
models/__pycache__/**
__pycache__/**
config/__pycache__/**
data_objects/__pycache__/**
logs_scratch/**
@@ -1,2 +1,98 @@
# AutoSpeech: Neural Architecture Search for Speaker Recognition

Code for the paper [AutoSpeech: Neural Architecture Search for Speaker Recognition](TBD)

Shaojin Ding*, Tianlong Chen*, Xinyu Gong, Weiwei Zha, Zhangyang Wang

## Overview
Speaker recognition systems based on Convolutional Neural Networks (CNNs) are often built with off-the-shelf backbones such as VGG-Net or ResNet. However, these backbones were originally proposed for image classification, so they may not be a natural fit for speaker recognition. Because manually exploring the design space is prohibitively complex, we propose the first neural architecture search approach for speaker recognition tasks, named AutoSpeech. Our algorithm first identifies the optimal operation combination in a neural cell and then derives a CNN model by stacking the neural cell multiple times. The final speaker recognition model is obtained by training the derived CNN model with the standard scheme. To evaluate the proposed approach, we conduct experiments on both speaker identification and speaker verification tasks using the VoxCeleb1 dataset. Results demonstrate that the CNN architectures derived by the proposed approach significantly outperform current speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones, while enjoying lower model complexity.
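To make the two-stage idea concrete, here is a purely illustrative sketch of the "stack one searched cell" step. `make_cell` is a hypothetical factory standing in for the repository's actual cell implementation, which is not shown in this excerpt.

```python
import torch.nn as nn

# Illustrative only: the searched cell design is reused as a building block and
# stacked `layers` times; a pooled embedding feeds the speaker classifier.
class StackedCellNet(nn.Module):
    def __init__(self, make_cell, layers, channels, num_classes):
        super().__init__()
        self.cells = nn.ModuleList([make_cell(channels) for _ in range(layers)])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x):
        for cell in self.cells:
            x = cell(x)
        x = self.pool(x).flatten(1)
        return self.classifier(x)
```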
## Quick start
### Requirements
* Python 3.7

* PyTorch >= 1.0: `pip install torch torchvision`

* Other dependencies: `pip install -r requirements.txt`
### Dataset
[VoxCeleb1](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html): You will need the `DevA-DevD` and `Test` parts. Additionally, you will need the original files `vox1_meta.csv`, `iden_split.txt`, and `veri_test.txt` from the official website.

The data should be organized as follows (a small layout check is sketched after the list):
* VoxCeleb1
  * wav
  * vox1_meta.csv
  * iden_split.txt
  * veri_test.txt
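A minimal sanity check of this layout (illustrative only; the path is a placeholder to adjust for your setup):

```python
from pathlib import Path

# Hypothetical helper: confirm the expected VoxCeleb1 entries exist before preprocessing.
root = Path('/path/to/VoxCeleb1')
expected = ['wav', 'vox1_meta.csv', 'iden_split.txt', 'veri_test.txt']
missing = [name for name in expected if not (root / name).exists()]
if missing:
    print('Missing entries:', missing)
else:
    print('VoxCeleb1 layout looks complete.')
```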
### Running the code
* Data preprocessing:

`python data_preprocess.py /path/to/VoxCeleb1`
* Training and evaluating the ResNet-18 and ResNet-34 baselines:

`python train_baseline.py --cfg exps/baseline/resnet18.yaml`

`python train_baseline.py --cfg exps/baseline/resnet34.yaml`

You need to modify the `DATA_DIR` field in the `.yaml` file.
* Architecture search:

`python search.py --cfg exps/search.yaml`

You need to modify the `DATA_DIR` field in the `.yaml` file.
* Training from scratch:

`python train.py --cfg exps/scratch/scratch.yaml --text_arch GENOTYPE`

You need to modify the `DATA_DIR` field in the `.yaml` file.

`GENOTYPE` is the searched architecture object. For example, the `GENOTYPE` of the architecture reported in the paper is:

`"Genotype(normal=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0), ('dil_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('dil_conv_3x3', 2), ('max_pool_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('dil_conv_5x5', 3), ('dil_conv_3x3', 2), ('dil_conv_5x5', 4), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6))"`
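For readers unfamiliar with this format, the sketch below shows the DARTS-style `Genotype` structure such strings describe. The field layout is inferred from the string above; the genotype values in the example are shortened and hypothetical, not the searched architecture.

```python
from collections import namedtuple

# DARTS-style genotype container inferred from the string above: each
# (operation, input_index) pair records which operation an intermediate node
# applies to which earlier node; *_concat lists the nodes whose outputs are
# concatenated to form the cell output.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Shortened, hypothetical example (not the architecture from the paper):
example = Genotype(
    normal=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0)],
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0)],
    reduce_concat=range(2, 6),
)
print(example.normal)
```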
* Evaluation:

  * Identification:

  `python evaluate_identification.py --cfg exps/scratch/scratch.yaml --load_path /path/to/the/trained/model`

  * Verification (this reports the equal error rate, EER; a reference EER computation is sketched after this list):

  `python evaluate_verification.py --cfg exps/scratch/scratch.yaml --load_path /path/to/the/trained/model`
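As a point of reference, EER can be computed from verification trial scores as below. This is a generic sketch, not necessarily identical to what `evaluate_verification.py` implements; the label and score arrays are assumed inputs.

```python
import numpy as np
from sklearn.metrics import roc_curve

def compute_eer(labels, scores):
    """Generic equal-error-rate computation from trial labels and similarity scores.

    labels: 1 for same-speaker trials, 0 for different-speaker trials.
    scores: higher values mean the model believes the speakers match.
    """
    fpr, tpr, _ = roc_curve(labels, scores)
    fnr = 1.0 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))   # operating point where FPR ~= FNR
    return (fpr[idx] + fnr[idx]) / 2.0
```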
### Visualization

normal cell | reduction cell
<p align="center">
<img src="figures/searched_arch_normal.png" alt="progress_convolutional_normal" width="45%">
<img src="figures/searched_arch_reduce.png" alt="progress_convolutional_reduce" width="45%">
</p>
## Results

Our proposed approach outperforms speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones. A detailed comparison can be found in our paper.

| Method | Top-1 (%) | EER (%) | Parameters |
| :------------: | :---: | :---: | :---: |
| VGG-M | 80.50 | 10.2 | 67M |
| ResNet-18 | 79.48 | 8.17 | 12M |
| ResNet-34 | 81.34 | 4.64 | 22M |
| Proposed | **87.66** | **1.45** | **18M** |
## Citation

If you use this code for your research, please cite our paper.

```
TBD
```
@@ -0,0 +1,22 @@
import torch


def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])


# Optimizes the architecture parameters returned by model.arch_parameters()
# with Adam, using the model's loss on validation batches.
class Architect(object):

    def __init__(self, model, cfg):
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=cfg.TRAIN.ARCH_LR, betas=(0.5, 0.999),
                                          weight_decay=cfg.TRAIN.ARCH_WD)

    def step(self, input_valid, target_valid):
        # One architecture-parameter update on a validation batch.
        self.optimizer.zero_grad()
        self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
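A hedged sketch of how this class is typically used in a DARTS-style search loop. The real loop lives in `search.py`, which is not part of this excerpt, and the loader and optimizer names below are assumptions.

```python
# Hypothetical search-loop usage (assumed names): architecture parameters are
# updated on held-out batches while the network weights are updated on
# training batches.
def search_one_epoch(model, architect, w_optimizer, train_loader, valid_loader):
    for (x_train, y_train), (x_valid, y_valid) in zip(train_loader, valid_loader):
        architect.step(x_valid, y_valid)      # update architecture parameters on validation data

        w_optimizer.zero_grad()
        loss = model._loss(x_train, y_train)  # same loss hook Architect relies on
        loss.backward()
        w_optimizer.step()
```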
@@ -0,0 +1,2 @@
from .default import _C as cfg
from .default import update_config
@@ -0,0 +1,69 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from yacs.config import CfgNode as CN


_C = CN()

_C.PRINT_FREQ = 20
_C.VAL_FREQ = 20

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# seed
_C.SEED = 3

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'foo_net'
_C.MODEL.NUM_CLASSES = 500
_C.MODEL.LAYERS = 8
_C.MODEL.INIT_CHANNELS = 16
_C.MODEL.DROP_PATH_PROB = 0.2
_C.MODEL.PRETRAINED = False

# DATASET related params
_C.DATASET = CN()
_C.DATASET.DATA_DIR = ''
_C.DATASET.DATASET = ''
_C.DATASET.TEST_DATA_DIR = ''
_C.DATASET.TEST_DATASET = ''
_C.DATASET.NUM_WORKERS = 0
_C.DATASET.PARTIAL_N_FRAMES = 32
_C.DATASET.FEATURE_DIM = 40


# train
_C.TRAIN = CN()

_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.LR = 0.1
_C.TRAIN.LR_MIN = 0.001
_C.TRAIN.WD = 0.0
_C.TRAIN.BETA1 = 0.9
_C.TRAIN.BETA2 = 0.999

_C.TRAIN.ARCH_LR = 0.1
_C.TRAIN.ARCH_WD = 0.0
_C.TRAIN.ARCH_BETA1 = 0.9
_C.TRAIN.ARCH_BETA2 = 0.999

_C.TRAIN.DROPPATH_PROB = 0.2

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 140


def update_config(cfg, args):
    cfg.defrost()
    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
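A minimal sketch of how these helpers might be driven from a training script. The argparse setup here is an assumption; the actual scripts are not shown in this excerpt.

```python
# Hypothetical driver: parse --cfg plus optional KEY VALUE overrides and merge
# them into the default config defined above.
import argparse
from config import cfg, update_config

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', required=True, help='path to a .yaml experiment file')
parser.add_argument('opts', nargs=argparse.REMAINDER,
                    help='optional KEY VALUE overrides, e.g. TRAIN.LR 0.05')
args = parser.parse_args()

update_config(cfg, args)   # merge_from_file(args.cfg), then merge_from_list(args.opts)
print(cfg.MODEL.NAME, cfg.TRAIN.LR)
```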
@@ -0,0 +1,74 @@
from __future__ import print_function


import numpy as np
import torch.utils.data as data
from data_objects.speaker import Speaker
from torchvision import transforms as T
from data_objects.transforms import Normalize, TimeReverse, generate_test_sequence


def find_classes(speakers):
    classes = list(set([speaker.name for speaker in speakers]))
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


class DeepSpeakerDataset(data.Dataset):

    def __init__(self, data_dir, partial_n_frames, partition=None, is_test=False):
        super(DeepSpeakerDataset, self).__init__()
        self.data_dir = data_dir
        self.root = data_dir.joinpath('feature')
        self.partition = partition
        self.partial_n_frames = partial_n_frames
        self.is_test = is_test

        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
        if len(speaker_dirs) == 0:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = [Speaker(speaker_dir, self.partition) for speaker_dir in speaker_dirs]

        classes, class_to_idx = find_classes(self.speakers)
        sources = []
        for speaker in self.speakers:
            sources.extend(speaker.sources)
        self.features = []
        for source in sources:
            item = (source[0].joinpath(source[1]), class_to_idx[source[2]])
            self.features.append(item)
        mean = np.load(self.data_dir.joinpath('mean.npy'))
        std = np.load(self.data_dir.joinpath('std.npy'))
        self.transform = T.Compose([
            Normalize(mean, std),
            TimeReverse(),
        ])

    def load_feature(self, feature_path, speaker_id):
        feature = np.load(feature_path)
        if self.is_test:
            test_sequence = generate_test_sequence(feature, self.partial_n_frames)
            return test_sequence, speaker_id
        else:
            # If the utterance is shorter than partial_n_frames, tile it until it is
            # long enough; otherwise sample a random window of partial_n_frames frames.
            if feature.shape[0] <= self.partial_n_frames:
                start = 0
                while feature.shape[0] < self.partial_n_frames:
                    feature = np.repeat(feature, 2, axis=0)
            else:
                start = np.random.randint(0, feature.shape[0] - self.partial_n_frames)
            end = start + self.partial_n_frames
            return feature[start:end], speaker_id

    def __getitem__(self, index):
        feature_path, speaker_id = self.features[index]
        feature, speaker_id = self.load_feature(feature_path, speaker_id)

        if self.transform is not None:
            feature = self.transform(feature)
        return feature, speaker_id

    def __len__(self):
        return len(self.features)
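A short usage sketch under stated assumptions: the placeholder path points at the preprocessed data directory (containing `feature/`, `mean.npy`, and `std.npy`), and the batch settings mirror the defaults in `config/default.py`.

```python
# Hypothetical usage (placeholder path): wrap the dataset in a standard DataLoader.
from pathlib import Path
from torch.utils.data import DataLoader

data_dir = Path('/path/to/preprocessed/VoxCeleb1')   # must contain feature/, mean.npy, std.npy
dataset = DeepSpeakerDataset(data_dir, partial_n_frames=32)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

for features, speaker_ids in loader:
    # features: a batch of normalized 32-frame spectrogram crops; speaker_ids: class indices
    break
```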