forked from golbin/WaveNet
Commit e391394 (0 parents): 17 changed files with 1,134 additions and 0 deletions.
@@ -0,0 +1,6 @@
.DS_Store
__pycache__
.cache
.idea
output
datasets
@@ -0,0 +1,74 @@
# WaveNet

Yet another WaveNet implementation in PyTorch.

The goal of this implementation is to be well-structured, reusable, and easy to understand.

- [WaveNet Paper](https://arxiv.org/pdf/1609.03499.pdf)
- [WaveNet: A Generative Model for Raw Audio](https://deepmind.com/blog/wavenet-generative-model-raw-audio/)

## Prerequisites

- System
  - Linux or macOS
  - CPU or (NVIDIA GPU + CUDA cuDNN)
  - Runs on a single CPU/GPU or on multiple GPUs.
- Python 3
- Libraries
  - PyTorch >= 0.3.0
  - librosa >= 0.5.1
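If you are setting up an environment from scratch, installing the dependencies might look roughly like this (a sketch; pick the PyTorch build matching your platform and CUDA version from pytorch.org):

```bash
# Assumes an existing Python 3 environment.
# Install PyTorch (>= 0.3.0) following the instructions at https://pytorch.org,
# then the audio dependency:
pip install "librosa>=0.5.1"
```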
## Training

```bash
python train.py \
  --data_dir=./test/data \
  --output_dir=./outputs
```

Use `python train.py --help` to see more options.
## Generating

This script is only for testing; you will need to adapt it for real-world use.

```bash
python generate.py \
  --model=./outputs/model \
  --seed=./test/data/helloworld.wav \
  --out=./output/helloworld.wav
```

Use `python generate.py --help` to see more options.

## File structures

`networks.py` and `model.py` contain the main implementation.

- wavenet
  - `config.py` : Training options
  - `networks.py` : The neural network architecture of WaveNet
  - `model.py` : Loss calculation and optimization (see the sketch after this list)
  - utils
    - `data.py` : Utilities for loading data
    - `logger.py` : Utilities for logging
- test
  - Tests that check components of the model, such as the causal and dilated convolutions
- `train.py` : A script for WaveNet training
- `generate.py` : A script for generating audio with a pre-trained model
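The `model.py` loss-and-optimization code itself is not part of this commit excerpt. As a rough sketch of what that step usually looks like for WaveNet (categorical cross-entropy over the mu-law classes, as in the paper), with `train_step`, `net`, and `optimizer` being hypothetical names rather than this repository's API:

```python
import torch.nn.functional as F

def train_step(net, optimizer, inputs, targets):
    """inputs: one-hot audio [batch, time, channels]; targets: mu-law class indices."""
    logits = net(inputs)                                         # [batch, time', channels]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),  # softmax over the
                           targets.reshape(-1))                  # 256 mu-law classes
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

`train.py` below calls `self.wavenet.train(inputs, targets)`, which presumably wraps a step of this kind.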
# TODO

- [ ] Faster generating
- [ ] Parallel WaveNet
- [ ] General Generator

## References

- https://github.com/ibab/tensorflow-wavenet
- https://qiita.com/MasaEguchi/items/cd5f7e9735a120f27e2a
- https://github.com/musyoku/wavenet/issues/4
- https://github.com/vincentherrmann/pytorch-wavenet
- http://sergeiturukin.com/2017/03/02/wavenet.html
@@ -0,0 +1,92 @@
"""
A script for generating audio with a pre-trained WaveNet model
"""
import torch
import librosa
import datetime
import numpy as np

import wavenet.config as config
from wavenet.model import WaveNet
import wavenet.utils.data as utils


class Generator:
    def __init__(self, args):
        self.args = args

        self.wavenet = WaveNet(args.layer_size, args.stack_size,
                               args.in_channels, args.res_channels)

    @staticmethod
    def _variable(data):
        # PyTorch 0.3-style Variable wrapper; moves the tensor to GPU when available.
        tensor = torch.from_numpy(data).float()

        if torch.cuda.is_available():
            return torch.autograd.Variable(tensor.cuda())
        else:
            return torch.autograd.Variable(tensor)

    def _make_seed(self, audio):
        # Left-pad with silence so the first generated sample has a full receptive field.
        audio = np.pad([audio], [[0, 0], [self.wavenet.receptive_fields, 0], [0, 0]], 'constant')

        if self.args.sample_size:
            seed = audio[:, :self.args.sample_size, :]
        else:
            seed = audio[:, :self.wavenet.receptive_fields*2, :]

        return seed

    def _get_seed_from_audio(self, filepath):
        audio = utils.load_audio(filepath, self.args.sample_rate)
        audio_length = len(audio)

        audio = utils.mu_law_encode(audio, self.args.in_channels)
        audio = utils.one_hot_encode(audio, self.args.in_channels)

        seed = self._make_seed(audio)

        return self._variable(seed), audio_length

    def _save_to_audio_file(self, data):
        data = data[0].cpu().data.numpy()
        data = utils.one_hot_decode(data, axis=1)
        audio = utils.mu_law_decode(data, self.args.in_channels)

        librosa.output.write_wav(self.args.out, audio, self.args.sample_rate)
        print('Saved wav file at {}'.format(self.args.out))

        return librosa.get_duration(y=audio, sr=self.args.sample_rate)

    def generate(self):
        outputs = []
        inputs, audio_length = self._get_seed_from_audio(self.args.seed)

        while True:
            new = self.wavenet.generate(inputs)

            outputs = torch.cat((outputs, new), dim=1) if len(outputs) else new

            print('{0}/{1} samples are generated.'.format(len(outputs[0]), audio_length))

            if len(outputs[0]) >= audio_length:
                break

            # Slide the input window: drop the oldest samples and append the new ones.
            inputs = torch.cat((inputs[:, :-len(new[0]), :], new), dim=1)

        outputs = outputs[:, :audio_length, :]

        return self._save_to_audio_file(outputs)


if __name__ == '__main__':
    args = config.parse_args(is_training=False)

    generator = Generator(args)

    start_time = datetime.datetime.now()

    duration = generator.generate()

    print('Generating {0} seconds of audio took {1}'.format(duration, datetime.datetime.now() - start_time))
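The `utils.load_audio`, mu-law, and one-hot helpers used above come from `wavenet/utils/data.py`, which this excerpt does not include. As a minimal NumPy sketch of the mu-law companding defined in the WaveNet paper (the function names and signatures here are assumptions about that module):

```python
import numpy as np

def mu_law_encode(audio, quantization_channels=256):
    """Map float audio in [-1, 1] to integer classes in [0, quantization_channels)."""
    mu = quantization_channels - 1
    # Companding transform: sign(x) * ln(1 + mu*|x|) / ln(1 + mu)
    compressed = np.sign(audio) * np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    # Shift from [-1, 1] to [0, mu] and quantize to integer classes.
    return ((compressed + 1) / 2 * mu + 0.5).astype(np.int64)

def mu_law_decode(encoded, quantization_channels=256):
    """Approximately invert mu_law_encode back to float audio in [-1, 1]."""
    mu = quantization_channels - 1
    signal = 2.0 * encoded / mu - 1.0
    return np.sign(signal) * ((1.0 + mu) ** np.abs(signal) - 1.0) / mu
```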
@@ -0,0 +1,63 @@
"""
A script for WaveNet training
"""
import os

import nsml
from line_notify import LineNotify

import wavenet.config as config
from wavenet.model import WaveNet
from wavenet.utils.data import DataLoader


notify = LineNotify("LrEZ94o3PAH4kZ82JHSkfjTQGbOsc1cY2iAKWYvYZr6", name="NSML.WaveNet")


class Trainer:
    def __init__(self, args):
        self.args = args

        self.wavenet = WaveNet(args.layer_size, args.stack_size,
                               args.in_channels, args.res_channels)

        self.data_loader = DataLoader(args.data_dir, self.wavenet.receptive_fields,
                                      args.sample_size, args.sample_rate, args.in_channels)

    def run(self):
        total_steps = 0

        for dataset in self.data_loader:
            for inputs, targets in dataset:
                loss = self.wavenet.train(inputs, targets)
                total_steps += 1

                print('[{0}/{1}] loss: {2}'.format(total_steps, self.args.num_steps, loss))
                notify.send('[{0}/{1}] loss: {2}'.format(total_steps, self.args.num_steps, loss))

                if total_steps > self.args.num_steps:
                    # Note: this break only exits the inner loop; iteration then
                    # continues with the next dataset from the data loader.
                    break

        self.wavenet.save(self.args.model_dir)

        notify.send('Training Finished!!')


def prepare_output_dir(args):
    args.log_dir = os.path.join(args.output_dir, 'log')
    args.model_dir = os.path.join(args.output_dir, 'model')
    args.test_output_dir = os.path.join(args.output_dir, 'test')

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.test_output_dir, exist_ok=True)


if __name__ == '__main__':
    args = config.parse_args()

    prepare_output_dir(args)

    trainer = Trainer(args)

    trainer.run()
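`wavenet/config.py` is imported by both scripts but not shown in this commit excerpt. A plausible sketch of its `parse_args()`, inferred only from the attributes that `train.py` and `generate.py` read; every flag name and default value below is an assumption, not the repository's documented configuration:

```python
import argparse

def parse_args(is_training=True):
    parser = argparse.ArgumentParser(description='WaveNet')

    # Network shape: WaveNet(layer_size, stack_size, in_channels, res_channels)
    parser.add_argument('--layer_size', type=int, default=10)    # dilations 2**0 .. 2**9
    parser.add_argument('--stack_size', type=int, default=5)     # number of dilation stacks
    parser.add_argument('--in_channels', type=int, default=256)  # mu-law / one-hot channels
    parser.add_argument('--res_channels', type=int, default=512) # residual channel width

    # Data options used by DataLoader and the seed loader
    parser.add_argument('--sample_rate', type=int, default=16000)
    parser.add_argument('--sample_size', type=int, default=100000)

    if is_training:
        parser.add_argument('--data_dir', type=str, default='./test/data')
        parser.add_argument('--output_dir', type=str, default='./outputs')
        parser.add_argument('--num_steps', type=int, default=100000)
    else:
        parser.add_argument('--model', type=str, required=True)  # pre-trained model dir
        parser.add_argument('--seed', type=str, required=True)   # seed .wav file
        parser.add_argument('--out', type=str, required=True)    # output .wav path

    return parser.parse_args()
```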
@@ -0,0 +1,13 @@
#nsml: floydhub/pytorch:0.3.0-gpu.cuda8cudnn6-py3.17
# The "#nsml:" directive above selects the Docker image used when this project runs on NSML.

from distutils.core import setup

setup(
    name='WaveNet example for NSML',
    version='0.1',
    description='WaveNet for NSML',
    install_requires=[
        'librosa',
        'line_notify'
    ]
)
Binary file not shown.
@@ -0,0 +1,92 @@
"""
Test Dilated Causal Convolution
"""

import os
import sys

import torch
import pytest
import numpy as np

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.networks import CausalConv1d, DilatedCausalConv1d


# Expected outputs for the test-initialised weights on the ramp input from generate_x().
CAUSAL_RESULT = [
    [[[18, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94]]]
]

DILATED_CAUSAL_RESULT = [
    [[[56, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184]]],
    [[[144, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352]]],
    [[[368, 416, 448, 480, 512, 544, 576, 608, 640]]],
    [[[1008]]]
]


def causal_conv(data, in_channels, out_channels, print_result=True):
    conv = CausalConv1d(in_channels, out_channels)
    conv.init_weights_for_test()

    output = conv(data)

    print('Causal convolution ---')
    if print_result:
        print(' {0}'.format(output.data.numpy().astype(int)))

    return output


def dilated_causal_conv(step, data, channels, dilation=1, print_result=True):
    conv = DilatedCausalConv1d(channels, dilation=dilation)
    conv.init_weights_for_test()

    output = conv(data)

    print('{0} step is OK: dilation={1}, size={2}'.format(step, dilation, output.shape))
    if print_result:
        print(' {0}'.format(output.data.numpy().astype(int)))

    return output


@pytest.fixture
def generate_x():
    """Build a ramp input tensor of shape [batch, channel, timestep]"""
    x = np.arange(1, 33, dtype=np.float32)
    x = np.reshape(x, [1, 2, 16])  # [batch, channel, timestep]
    x = torch.autograd.Variable(torch.from_numpy(x))

    print('Input size={0}'.format(x.shape))
    print(x.data.numpy().astype(int))
    print('-'*80)

    return x


@pytest.fixture
def test_causal_conv(generate_x):
    """Check the causal convolution and pass its output on to the dilated test"""
    result = causal_conv(generate_x, 2, 1)

    np.testing.assert_array_equal(
        result.data.numpy().astype(int),
        CAUSAL_RESULT[0]
    )

    return result


def test_dilated_causal_conv(test_causal_conv):
    """Test dilated causal convolution: dilation=[1, 2, 4, 8]"""
    result = test_causal_conv

    for i in range(0, 4):
        result = dilated_causal_conv(i+1, result, 1, dilation=2**i)

        np.testing.assert_array_equal(
            result.data.numpy().astype(int),
            DILATED_CAUSAL_RESULT[i]
        )
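The `CausalConv1d` and `DilatedCausalConv1d` layers under test are defined in `wavenet/networks.py`, which is not part of this excerpt. A minimal sketch that reproduces the expected arrays above (kernel size 2, no bias, weights set to 1 by `init_weights_for_test`); the repository's actual implementation may differ in details:

```python
import torch.nn as nn

class CausalConv1d(nn.Module):
    """Conv1d padded on the left so that output t depends only on inputs <= t."""
    def __init__(self, in_channels, out_channels):
        super(CausalConv1d, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=2, padding=1, bias=False)

    def init_weights_for_test(self):
        self.conv.weight.data.fill_(1)

    def forward(self, x):
        # padding=1 pads both ends; dropping the last step keeps the output causal
        # and the same length as the input.
        return self.conv(x)[:, :, :-1]


class DilatedCausalConv1d(nn.Module):
    """Unpadded dilated Conv1d; each application shortens the sequence by `dilation`."""
    def __init__(self, channels, dilation=1):
        super(DilatedCausalConv1d, self).__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=2,
                              dilation=dilation, padding=0, bias=False)

    def init_weights_for_test(self):
        self.conv.weight.data.fill_(1)

    def forward(self, x):
        return self.conv(x)
```

With unit weights, every output sample is simply the sum of a causal two-tap window, which is how the values in `CAUSAL_RESULT` and `DILATED_CAUSAL_RESULT` arise from the ramp input.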
@@ -0,0 +1,50 @@
"""
Test the DataLoader: shapes of the generated inputs and targets
"""

import os
import sys

import torch

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.utils.data import DataLoader


RECEPTIVE_FIELDS = 1000
SAMPLE_SIZE = 2000
SAMPLE_RATE = 8000
IN_CHANNELS = 256
TEST_AUDIO_DIR = os.path.join(os.path.dirname(__file__), 'data')


def test_data_loader():
    data_loader = DataLoader(TEST_AUDIO_DIR,
                             RECEPTIVE_FIELDS, SAMPLE_SIZE, SAMPLE_RATE, IN_CHANNELS,
                             shuffle=False)

    dataset_size = []

    for dataset in data_loader:
        input_size = []
        target_size = []

        for i, t in dataset:
            input_size.append(i.shape)
            target_size.append(t.shape)

        dataset_size.append([input_size, target_size])

    # Inputs are one-hot chunks of shape [batch, timesteps, IN_CHANNELS]; each target
    # (a class-index sequence) is RECEPTIVE_FIELDS timesteps shorter than its input chunk.
    assert dataset_size[0][0][0] == torch.Size([1, 2000, 256])
    assert dataset_size[0][1][0] == torch.Size([1, 1000])
    assert dataset_size[0][0][-1] == torch.Size([1, 1839, 256])
    assert dataset_size[0][1][-1] == torch.Size([1, 839])

    assert dataset_size[1][0][0] == torch.Size([1, 2000, 256])
    assert dataset_size[1][1][0] == torch.Size([1, 1000])
    assert dataset_size[1][0][-1] == torch.Size([1, 1762, 256])
    assert dataset_size[1][1][-1] == torch.Size([1, 762])

    assert len(dataset_size[0][0]) == 8
    assert len(dataset_size[1][0]) == 8