commit 3200261 (parent 5c3bbb1)
first commit
shaojinding committed May 7, 2020
Showing 45 changed files with 3,255 additions and 1 deletion.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
logs/**
logs_search/**
models/__pycache__/**
__pycache__/**
config/__pycache__/**
data_objects/__pycache__/**
logs_scratch/**
8 changes: 8 additions & 0 deletions .idea/.gitignore

10 changes: 10 additions & 0 deletions .idea/AutoSpeech.iml

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

6 changes: 6 additions & 0 deletions .idea/libraries/R_User_Library.xml

7 changes: 7 additions & 0 deletions .idea/misc.xml

8 changes: 8 additions & 0 deletions .idea/modules.xml

6 changes: 6 additions & 0 deletions .idea/vcs.xml
98 changes: 97 additions & 1 deletion README.md
@@ -1,2 +1,98 @@
# AutoSpeech: Neural Architecture Search for Speaker Recognition

Code for the paper [AutoSpeech: Neural Architecture Search for Speaker Recognition](TBD)

Shaojin Ding*, Tianlong Chen*, Xinyu Gong, Weiwei Zha, Zhangyang Wang

## Overview
Speaker recognition systems based on Convolutional Neural Networks (CNNs) are often built with off-the-shelf backbones such as VGG-Net or ResNet. However, these backbones were originally proposed for image classification and may therefore not be a natural fit for speaker recognition. Because manually exploring the design space is prohibitively expensive, we propose the first neural architecture search approach for speaker recognition tasks, named AutoSpeech. Our algorithm first identifies the optimal operation combination in a neural cell and then derives a CNN model by stacking that cell multiple times. The final speaker recognition model is obtained by training the derived CNN with the standard scheme. To evaluate the proposed approach, we conduct experiments on both speaker identification and speaker verification using the VoxCeleb1 dataset. Results demonstrate that the CNN architectures derived by the proposed approach significantly outperform speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones, while enjoying lower model complexity.
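To make the search step concrete, here is a minimal sketch of the continuous-relaxation idea used by cell-based NAS (the DARTS-style scheme that the `Genotype` format below follows): each edge of a cell mixes candidate operations with learned weights, and the strongest operation per edge is kept when the cell is discretized. This is an illustrative sketch, not this repository's implementation, and the `OPS` entries are simplified stand-ins.

```python
import torch
import torch.nn as nn

# Simplified stand-ins for the candidate operations; the real set
# (sep_conv, dil_conv, pooling, ...) lives in the repository's model code.
OPS = {
    'conv_3x3': lambda c: nn.Conv2d(c, c, 3, padding=1),
    'max_pool_3x3': lambda c: nn.MaxPool2d(3, stride=1, padding=1),
    'skip_connect': lambda c: nn.Identity(),
}


class MixedOp(nn.Module):
    """One cell edge: a softmax-weighted sum of all candidate operations."""

    def __init__(self, channels):
        super().__init__()
        self.ops = nn.ModuleList(op(channels) for op in OPS.values())

    def forward(self, x, alphas):
        # alphas holds one learnable score per candidate op on this edge.
        weights = torch.softmax(alphas, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```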


## Quick start
### Requirements
* Python 3.7

* PyTorch >= 1.0: `pip install torch torchvision`

* Other dependencies: `pip install -r requirements.txt`

### Dataset
[VoxCeleb1](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html): you will need the `DevA-DevD` and `Test` parts. Additionally, you will need the original files `vox1_meta.csv`, `iden_split.txt`, and `veri_test.txt` from the official website.

The data should be organized as follows:

* VoxCeleb1
  * wav
  * vox1_meta.csv
  * iden_split.txt
  * veri_test.txt

### Running the code
* Data preprocessing:

`python data_preprocess.py /path/to/VoxCeleb1`
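`data_preprocess.py` is not shown in this section; judging from what `DeepSpeakerDataset.py` (below) expects, preprocessing should leave a `feature/` directory of per-speaker folders plus `mean.npy` and `std.npy` next to it. A hedged sanity-check sketch, with a placeholder path:

```python
from pathlib import Path
import numpy as np

data_dir = Path('/path/to/VoxCeleb1')  # placeholder

assert (data_dir / 'feature').is_dir()   # one subdirectory per speaker
assert (data_dir / 'mean.npy').exists()  # global feature mean, used by Normalize
assert (data_dir / 'std.npy').exists()   # global feature std, used by Normalize

sample = next((data_dir / 'feature').rglob('*.npy'))
print(np.load(sample).shape)  # expected (n_frames, 40), given FEATURE_DIM = 40
```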

* Training and evaluating ResNet-18, ResNet-34 baselines:

`python train_baseline.py --cfg exps/baseline/resnet18.yaml`

`python train_baseline.py --cfg exps/baseline/resnet34.yaml`

You need to modify the `DATA_DIR` field in the `.yaml` file.

* Architecture search:

`python search.py --cfg exps/search.yaml`

You need to modify the `DATA_DIR` field in the `.yaml` file.

* Training from scratch:

`python train.py --cfg exps/scratch/scratch.yaml --text_arch GENOTYPE`

You need to modify the `DATA_DIR` field in the `.yaml` file.

`GENOTYPE` is the searched architecture object. For example, the `GENOTYPE` of the architecture reported in the paper is:

`"Genotype(normal=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0), ('dil_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('dil_conv_3x3', 2), ('max_pool_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('dil_conv_5x5', 3), ('dil_conv_3x3', 2), ('dil_conv_5x5', 4), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6))"`

* Evaluation:

* Identification

`python evaluate_identification.py --cfg exps/scratch/scratch.yaml --load_path /path/to/the/trained/model`

* Verification

`python evaluate_verification.py --cfg exps/scratch/scratch.yaml --load_path /path/to/the/trained/model`


### Visualization

normal cell | reduction cell
<p align="center">
<img src="figures/searched_arch_normal.png" alt="progress_convolutional_normal" width="45%">
<img src="figures/searched_arch_reduce.png" alt="progress_convolutional_reduce" width="45%">
</p>

## Results

Our proposed approach outperforms speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones. The detailed comparison can be found in our paper.

| Method | Top-1 (%) | EER (%) | Parameters |
| :------------: | :---: | :---: | :---: |
| VGG-M | 80.50 | 10.2 | 67M |
| ResNet-18 | 79.48 | 8.17 | 12M |
| ResNet-34 | 81.34 | 4.64 | 22M |
| Proposed | **87.66** | **1.45** | **18M** |
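For reference, Top-1 is identification accuracy and EER (equal error rate) is the verification operating point where the false-acceptance and false-rejection rates are equal. A common way to compute EER from pairwise scores, shown as a sketch rather than this repository's evaluation code:

```python
import numpy as np
from sklearn.metrics import roc_curve

def compute_eer(labels, scores):
    """labels: 1 for same-speaker pairs, 0 otherwise; scores: similarity."""
    fpr, tpr, _ = roc_curve(labels, scores)
    fnr = 1 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))  # point where FAR and FRR cross
    return (fpr[idx] + fnr[idx]) / 2

print(compute_eer([1, 0, 1, 0], [0.9, 0.3, 0.8, 0.4]))  # 0.0 on separable toy scores
```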


## Citation

If you use this code for your research, please cite our paper.

```
TBD
```
22 changes: 22 additions & 0 deletions architect.py
@@ -0,0 +1,22 @@
import torch


def _concat(xs):
    # Flatten and concatenate a list of tensors into a single 1-D tensor.
    return torch.cat([x.view(-1) for x in xs])


class Architect(object):
    """First-order DARTS-style updater: optimizes the architecture
    parameters (alphas) with Adam on validation batches."""

    def __init__(self, model, cfg):
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=cfg.TRAIN.ARCH_LR, betas=(0.5, 0.999),
                                          weight_decay=cfg.TRAIN.ARCH_WD)

    def step(self, input_valid, target_valid):
        self.optimizer.zero_grad()
        self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
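`Architect` performs the first-order architecture update: the alphas returned by `model.arch_parameters()` are optimized with Adam on validation batches, separately from the network weights. A hedged sketch of one search iteration, assuming `model`, `optimizer`, `train_loader`, and `valid_loader` already exist as in a typical DARTS-style search script:

```python
architect = Architect(model, cfg)
for (x_train, y_train), (x_val, y_val) in zip(train_loader, valid_loader):
    architect.step(x_val, y_val)  # update architecture alphas on validation data
    optimizer.zero_grad()
    model._loss(x_train, y_train).backward()  # update network weights on training data
    optimizer.step()
```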
2 changes: 2 additions & 0 deletions config/__init__.py
@@ -0,0 +1,2 @@
from .default import _C as cfg
from .default import update_config
69 changes: 69 additions & 0 deletions config/default.py
@@ -0,0 +1,69 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from yacs.config import CfgNode as CN


_C = CN()

_C.PRINT_FREQ = 20
_C.VAL_FREQ = 20

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# seed
_C.SEED = 3

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'foo_net'
_C.MODEL.NUM_CLASSES = 500
_C.MODEL.LAYERS = 8
_C.MODEL.INIT_CHANNELS = 16
_C.MODEL.DROP_PATH_PROB = 0.2
_C.MODEL.PRETRAINED = False

# DATASET related params
_C.DATASET = CN()
_C.DATASET.DATA_DIR = ''
_C.DATASET.DATASET = ''
_C.DATASET.TEST_DATA_DIR = ''
_C.DATASET.TEST_DATASET = ''
_C.DATASET.NUM_WORKERS = 0
_C.DATASET.PARTIAL_N_FRAMES = 32
_C.DATASET.FEATURE_DIM = 40


# train
_C.TRAIN = CN()

_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.LR = 0.1
_C.TRAIN.LR_MIN = 0.001
_C.TRAIN.WD = 0.0
_C.TRAIN.BETA1 = 0.9
_C.TRAIN.BETA2 = 0.999

_C.TRAIN.ARCH_LR = 0.1
_C.TRAIN.ARCH_WD = 0.0
_C.TRAIN.ARCH_BETA1 = 0.9
_C.TRAIN.ARCH_BETA2 = 0.999

_C.TRAIN.DROPPATH_PROB = 0.2

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 140



def update_config(cfg, args):
    # Merge the experiment .yaml file first, then any KEY VALUE pairs
    # passed on the command line.
    cfg.defrost()
    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
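A minimal driver sketch for this yacs config, mirroring the `--cfg` flag used in the README commands above; `opts` collects optional `KEY VALUE` overrides:

```python
import argparse
from config import cfg, update_config

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', required=True, help='path to a .yaml experiment file')
parser.add_argument('opts', nargs=argparse.REMAINDER,
                    help='overrides, e.g. TRAIN.BATCH_SIZE 64')
args = parser.parse_args()

update_config(cfg, args)     # merge the YAML file, then the CLI overrides
print(cfg.TRAIN.BATCH_SIZE)  # cfg is frozen after update_config
```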
74 changes: 74 additions & 0 deletions data_objects/DeepSpeakerDataset.py
@@ -0,0 +1,74 @@
from __future__ import print_function


import numpy as np
import torch.utils.data as data
from data_objects.speaker import Speaker
from torchvision import transforms as T
from data_objects.transforms import Normalize, TimeReverse, generate_test_sequence


def find_classes(speakers):
    # Map sorted unique speaker names to contiguous class indices.
    classes = list(set([speaker.name for speaker in speakers]))
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


class DeepSpeakerDataset(data.Dataset):

    def __init__(self, data_dir, partial_n_frames, partition=None, is_test=False):
        super(DeepSpeakerDataset, self).__init__()
        self.data_dir = data_dir
        self.root = data_dir.joinpath('feature')
        self.partition = partition
        self.partial_n_frames = partial_n_frames
        self.is_test = is_test

        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
        if len(speaker_dirs) == 0:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = [Speaker(speaker_dir, self.partition) for speaker_dir in speaker_dirs]

        classes, class_to_idx = find_classes(self.speakers)
        sources = []
        for speaker in self.speakers:
            sources.extend(speaker.sources)
        # Each entry is a (feature_file_path, class_index) pair.
        self.features = []
        for source in sources:
            item = (source[0].joinpath(source[1]), class_to_idx[source[2]])
            self.features.append(item)
        mean = np.load(self.data_dir.joinpath('mean.npy'))
        std = np.load(self.data_dir.joinpath('std.npy'))
        self.transform = T.Compose([
            Normalize(mean, std),
            TimeReverse(),
        ])

    def load_feature(self, feature_path, speaker_id):
        feature = np.load(feature_path)
        if self.is_test:
            test_sequence = generate_test_sequence(feature, self.partial_n_frames)
            return test_sequence, speaker_id
        else:
            # Tile short utterances until they reach partial_n_frames;
            # take a random crop from longer ones.
            if feature.shape[0] <= self.partial_n_frames:
                start = 0
                while feature.shape[0] < self.partial_n_frames:
                    feature = np.repeat(feature, 2, axis=0)
            else:
                start = np.random.randint(0, feature.shape[0] - self.partial_n_frames)
            end = start + self.partial_n_frames
            return feature[start:end], speaker_id

    def __getitem__(self, index):
        feature_path, speaker_id = self.features[index]
        feature, speaker_id = self.load_feature(feature_path, speaker_id)

        if self.transform is not None:
            feature = self.transform(feature)
        return feature, speaker_id

    def __len__(self):
        return len(self.features)
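A hedged usage sketch wrapping the dataset in a standard `DataLoader`; the path is a placeholder, and 32 and 40 come from `PARTIAL_N_FRAMES` and `FEATURE_DIM` in `config/default.py`:

```python
from pathlib import Path
from torch.utils.data import DataLoader
from data_objects.DeepSpeakerDataset import DeepSpeakerDataset

dataset = DeepSpeakerDataset(Path('/path/to/VoxCeleb1'),  # placeholder path
                             partial_n_frames=32)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

features, speaker_ids = next(iter(loader))
print(features.shape)  # expected (4, 32, 40): batch x frames x feature dim
```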
