Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
golbin committed Jan 5, 2018
0 parents commit e391394
Show file tree
Hide file tree
Showing 17 changed files with 1,134 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.DS_Store
__pycache__
.cache
.idea
output
datasets
74 changes: 74 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# WaveNet

Yet another WaveNet implementation in PyTorch.

The purpose of this implementation is Well-structured, reusable and easily understandable.

- [WaveNet Paper](https://arxiv.org/pdf/1609.03499.pdf)
- [WaveNet: A Generative Model for Raw Audio](https://deepmind.com/blog/wavenet-generative-model-raw-audio/)

## Prerequisites

- System
- Linux or macOS
- CPU or (NVIDIA GPU + CUDA CuDNN)
- It can run on Single CPU/GPU or Multi GPUs.
- Python 3

- Libraries
- PyTorch >= 0.3.0
- librosa >= 0.5.1

## Training

```bash
python train.py \
--data_dir=./test/data \
--output_dir=./outputs
```

Use `python train.py --help` to see more options.

## Generating

It's just for testing. You need to modify for real world.

```bash
python generate.py \
--model=./outputs/model \
--seed=./test/data/helloworld.wav \
--out=./output/helloworld.wav
```

Use `python generate.py --help` to see more options.

## File structures

`modules.py` and `model.py` is main implementations.

- wavenet
- `config.py` : Training options
- `networks.py` : The neural network architecture of WaveNet
- `model.py` : Calculate loss and optimizing
- utils
- `data.py` : Utilities for loading data
- `logger.py` : Utilities for logging
- test
- Some tests for check if it's correct model like casual, dilated..
- `train.py` : A script for WaveNet training
- `generate.py` : A script for generating with pre-trained model

# TODO

- [ ] Faster generating
- [ ] Parallel WaveNet
- [ ] General Generator

## References

- https://github.com/ibab/tensorflow-wavenet
- https://qiita.com/MasaEguchi/items/cd5f7e9735a120f27e2a
- https://github.com/musyoku/wavenet/issues/4
- https://github.com/vincentherrmann/pytorch-wavenet
- http://sergeiturukin.com/2017/03/02/wavenet.html

92 changes: 92 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
A script for WaveNet training
"""
import torch
import librosa
import datetime
import numpy as np

import wavenet.config as config
from wavenet.model import WaveNet
import wavenet.utils.data as utils


class Generator:
def __init__(self, args):
self.args = args

self.wavenet = WaveNet(args.layer_size, args.stack_size,
args.in_channels, args.res_channels)

@staticmethod
def _variable(data):
tensor = torch.from_numpy(data).float()

if torch.cuda.is_available():
return torch.autograd.Variable(tensor.cuda())
else:
return torch.autograd.Variable(tensor)

def _make_seed(self, audio):
audio = np.pad([audio], [[0, 0], [self.wavenet.receptive_fields, 0], [0, 0]], 'constant')

if self.args.sample_size:
seed = audio[:, :self.args.sample_size, :]
else:
seed = audio[:, :self.wavenet.receptive_fields*2, :]

return seed

def _get_seed_from_audio(self, filepath):
audio = utils.load_audio(filepath, self.args.sample_rate)
audio_length = len(audio)

audio = utils.mu_law_encode(audio, self.args.in_channels)
audio = utils.one_hot_encode(audio, self.args.in_channels)

seed = self._make_seed(audio)

return self._variable(seed), audio_length

def _save_to_audio_file(self, data):
data = data[0].cpu().data.numpy()
data = utils.one_hot_decode(data, axis=1)
audio = utils.mu_law_decode(data, self.args.in_channels)

librosa.output.write_wav(self.args.out, audio, self.args.sample_rate)
print('Saved wav file at {}'.format(self.args.out))

return librosa.get_duration(y=audio, sr=self.args.sample_rate)

def generate(self):
outputs = []
inputs, audio_length = self._get_seed_from_audio(self.args.seed)

while True:
new = self.wavenet.generate(inputs)

outputs = torch.cat((outputs, new), dim=1) if len(outputs) else new

print('{0}/{1} samples are generated.'.format(len(outputs[0]), audio_length))

if len(outputs[0]) >= audio_length:
break

inputs = torch.cat((inputs[:, :-len(new[0]), :], new), dim=1)

outputs = outputs[:, :audio_length, :]

return self._save_to_audio_file(outputs)


if __name__ == '__main__':
args = config.parse_args(is_training=False)

generator = Generator(args)

start_time = datetime.datetime.now()

duration = generator.generate()

print('Generate {0} seconds took {1}'.format(duration, datetime.datetime.now() - start_time))

63 changes: 63 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
A script for WaveNet training
"""
import os

import nsml
from line_notify import LineNotify

import wavenet.config as config
from wavenet.model import WaveNet
from wavenet.utils.data import DataLoader


notify = LineNotify("LrEZ94o3PAH4kZ82JHSkfjTQGbOsc1cY2iAKWYvYZr6", name="NSML.WaveNet")


class Trainer:
def __init__(self, args):
self.args = args

self.wavenet = WaveNet(args.layer_size, args.stack_size,
args.in_channels, args.res_channels)

self.data_loader = DataLoader(args.data_dir, self.wavenet.receptive_fields,
args.sample_size, args.sample_rate, args.in_channels)

def run(self):
total_steps = 0

for dataset in self.data_loader:
for inputs, targets in dataset:
loss = self.wavenet.train(inputs, targets)
total_steps += 1

print('[{0}/{1}] loss: {2}'.format(total_steps, args.num_steps, loss))
notify.send('[{0}/{1}] loss: {2}'.format(total_steps, args.num_steps, loss))

if total_steps > args.num_steps:
break

self.wavenet.save(args.model_dir)

notify.send('Training Finished!!')


def prepare_output_dir(args):
args.log_dir = os.path.join(args.output_dir, 'log')
args.model_dir = os.path.join(args.output_dir, 'model')
args.test_output_dir = os.path.join(args.output_dir, 'test')

os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.model_dir, exist_ok=True)
os.makedirs(args.test_output_dir, exist_ok=True)


if __name__ == '__main__':
args = config.parse_args()

prepare_output_dir(args)

trainer = Trainer(args)

trainer.run()
13 changes: 13 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#nsml: floydhub/pytorch:0.3.0-gpu.cuda8cudnn6-py3.17

from distutils.core import setup

setup(
name='WaveNet example for NSML',
version='0.1',
description='WaveNet for NSML',
install_requires=[
'librosa',
'line_notify'
]
)
Binary file added test/data/helloworld.wav
Binary file not shown.
92 changes: 92 additions & 0 deletions test/test_causal_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Test Dilated Causal Convolution
"""

import os
import sys

import torch
import pytest
import numpy as np

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.networks import CausalConv1d, DilatedCausalConv1d


CAUSAL_RESULT = [
[[[18, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94]]]
]

DILATED_CAUSAL_RESULT = [
[[[56, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184]]],
[[[144, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352]]],
[[[368, 416, 448, 480, 512, 544, 576, 608, 640]]],
[[[1008]]]
]


def causal_conv(data, in_channels, out_channels, print_result=True):
conv = CausalConv1d(in_channels, out_channels)
conv.init_weights_for_test()

output = conv(data)

print('Causal convolution ---')
if print_result:
print(' {0}'.format(output.data.numpy().astype(int)))

return output


def dilated_causal_conv(step, data, channels, dilation=1, print_result=True):
conv = DilatedCausalConv1d(channels, dilation=dilation)
conv.init_weights_for_test()

output = conv(data)

print('{0} step is OK: dilation={1}, size={2}'.format(step, dilation, output.shape))
if print_result:
print(' {0}'.format(output.data.numpy().astype(int)))

return output


@pytest.fixture
def generate_x():
"""Test normal convolution 1d"""
x = np.arange(1, 33, dtype=np.float32)
x = np.reshape(x, [1, 2, 16]) # [batch, channel, timestep]
x = torch.autograd.Variable(torch.from_numpy(x))

print('Input size={0}'.format(x.shape))
print(x.data.numpy().astype(int))
print('-'*80)

return x


@pytest.fixture
def test_causal_conv(generate_x):
"""Test normal convolution 1d"""
result = causal_conv(generate_x, 2, 1)

np.testing.assert_array_equal(
result.data.numpy().astype(int),
CAUSAL_RESULT[0]
)

return result


def test_dilated_causal_conv(test_causal_conv):
"""Test dilated causal convolution : dilation=[1, 2, 4, 8]"""
result = test_causal_conv

for i in range(0, 4):
result = dilated_causal_conv(i+1, result, 1, dilation=2**i)

np.testing.assert_array_equal(
result.data.numpy().astype(int),
DILATED_CAUSAL_RESULT[i]
)

50 changes: 50 additions & 0 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Test mu-law encoding and decoding
"""

import os
import sys

import torch

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from wavenet.utils.data import DataLoader


RECEPTIVE_FIELDS = 1000
SAMPLE_SIZE = 2000
SAMPLE_RATE = 8000
IN_CHANNELS = 256
TEST_AUDIO_DIR = os.path.join(os.path.dirname(__file__), 'data')


def test_data_loader():
data_loader = DataLoader(TEST_AUDIO_DIR,
RECEPTIVE_FIELDS, SAMPLE_SIZE, SAMPLE_RATE, IN_CHANNELS,
shuffle=False)

dataset_size = []

for dataset in data_loader:
input_size = []
target_size = []

for i, t in dataset:
input_size.append(i.shape)
target_size.append(t.shape)

dataset_size.append([input_size, target_size])

assert dataset_size[0][0][0] == torch.Size([1, 2000, 256])
assert dataset_size[0][1][0] == torch.Size([1, 1000])
assert dataset_size[0][0][-1] == torch.Size([1, 1839, 256])
assert dataset_size[0][1][-1] == torch.Size([1, 839])

assert dataset_size[1][0][0] == torch.Size([1, 2000, 256])
assert dataset_size[1][1][0] == torch.Size([1, 1000])
assert dataset_size[1][0][-1] == torch.Size([1, 1762, 256])
assert dataset_size[1][1][-1] == torch.Size([1, 762])

assert len(dataset_size[0][0]) == 8
assert len(dataset_size[1][0]) == 8

Loading

0 comments on commit e391394

Please sign in to comment.