Skip to content

Commit 1b1e825

Browse files
authored
Merge pull request #7 from nagyrajmund/GENEA_2020
Update decoding.py to use the new parameter handling
2 parents 8c4025d + 383de9e commit 1b1e825

File tree

9 files changed

+79
-520
lines changed

9 files changed

+79
-520
lines changed

README.md

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ ________________________________________________________________________________
3030

3131
# How to use this repository?
3232

33+
# 0. Notation
34+
35+
Whenever a parameter is written in caps (such as DATA_DIR), it has to be specified by the user on the command line as a positional argument.
36+
3337
## 1. Obtain raw data
3438

3539
- Clone this repository
@@ -42,8 +46,9 @@ git checkout GENEA_2020
4246
```
4347
- Download a dataset from KTH Box using the link you obtained after singing the license agreement
4448

45-
4649
## 2. Pre-process the data
50+
By default, the model expects the dataset in the `<repository>/dataset/raw` folder, and the processed dataset will be available in the `<repository>/dataset/processed folder`. If your dataset is elsewhere, please provide the correct paths with the `--raw_data_dir` and `--proc_data_dir` command line arguments. You can also use '--help' argument to see more details about the scripts.
51+
4752
```
4853
cd data_processing
4954
@@ -59,23 +64,20 @@ python process_dataset.py
5964
cd ..
6065
```
6166

62-
By default, the model expects the dataset in the `<repository>/dataset/raw` folder, and the processed dataset will be available in the `<repository>/dataset/processed folder`. If your dataset is elsewhere, please provide the correct paths with the `--raw_data_dir` and `--proc_data_dir` command line arguments for the 'split_dataset.py' and `process_dataset.py`. You can also use '--help' argument to see more details about the scripts.
63-
64-
As a result of running this script
65-
- numpy binary files `X_train.npy`, `Y_train.npy` (training dataset files) are created under `--proc_data_dir`
66-
- under `/test_inputs/` subfolder of the processed dataset folder test audios, such as `X_test_audio1168.npy` , are created
67+
As a result of running this script, the dataset is created in `--proc_data_dir`:
68+
- the training dataset files `X_train.npy`, `Y_train.npy` and the validation dataset files `X_dev.npy`, `Y_dev.npy`are binary numpy files
69+
- the audio inputs for testing (such as `X_test_NaturalTalking_04.npy`) are under the `/test_inputs/` subfolder
6770

71+
There rest of the folders in `--proc_data_dir` (e.g. `/dev_inputs/` or `/train/`) can be ignored (they are a side effect of the preprocessing script).
6872

69-
## 3. Learn motion representation by AutoEncoder and Encode the datset
70-
71-
Create a directory to save training checkpoints such as `chkpt/` and use it as CHKPT_DIR parameter.
72-
#### Learn dataset encoding and encode the training and validation datasets
73-
```sh
74-
python motion_repr_learning/ae/learn_ae_n_encode_dataset.py --data_dir <path/to/your/dataset> --layer1_width 40
73+
## 3. Learn motion representation by AutoEncoder and encode the training and validation datasets
74+
```python
75+
python motion_repr_learning/ae/learn_ae_n_encode_dataset.py --layer1_width DIM
7576
```
77+
There are several parameters that can be modified in the `config.yaml` file or through the command line, see `config.py` for details.
78+
The optimal dimensionality (DIM) in our experiment was 40.
7679

77-
The optimal dimensionality (DIM) in our experiment was 40
78-
80+
More information can be found in the folder `motion_repr_learning`
7981

8082
## 4. Learn speech-driven gesture generation model
8183

@@ -97,15 +99,14 @@ python predict.py MODEL_NAME.hdf5 INPUT_SPEECH_FILE OUTPUT_GESTURE_FILE
9799

98100
```sh
99101
# Usage example
100-
python predict.py model.hdf5 data/test_inputs/X_test_audio1168.npy data/test_inputs/predict_1168_20fps.txt
102+
python predict.py model.hdf5 data/test_inputs/X_test_NaturalTalking_04.npy data/test_inputs/predict_04_20fps.txt
101103
```
102104

105+
The predicted gestures have to be decoded with `decode.py`, which reuses the config from step 3.
103106
```sh
104-
# You need to decode the gestures
105-
python motion_repr_learning/ae/decode.py DATA_DIR ENCODED_PREDICTION_FILE DECODED_GESTURE_FILE -restore=True -pretrain=False -layer1_width=DIM -chkpt_dir=CHKPT_DIR -batch_size=8
107+
python motion_repr_learning/ae/decode.py python decode.py -input_file INPUT_FILE -output_file OUTPUT_FILE --layer1_width DIM --batch_size=8
106108
```
107109

108-
109110
## 6. Quantitative evaluation
110111
Use scripts in the `evaluation` folder of this directory.
111112

config.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,34 @@
33

44
# NOTE: the global variable 'args' for accessing the config parameters from other modules
55
# is defined at the very bottom of this file
6-
7-
# Modify this function to set the default home directory for this repo
8-
def home_out(path):
9-
return os.path.join(os.environ['HOME'], 'tmp', 'MoCap', path)
10-
116
def construct_config_parser():
127
parser = ArgumentParser(args_for_setting_config_path = ['-config'],
138
default_config_files = ['./config.yaml'],
149
config_file_parser_class = YAMLConfigFileParser)
1510

1611
parser.add('--seed', type=int, help='Random seed')
1712

18-
# ---- The data directories ----
13+
# ---- Data directories ----
14+
15+
parser.add('--data_dir', help='The directory with the preprocessed dataset')
16+
parser.add('--summary_dir', help='Directory for saving the summary data')
17+
parser.add('--chkpt_dir', help='Directory for saving the model checkpoints')
18+
parser.add('--results_file', help='File for saving the results of the experiments')
19+
20+
# ---- Input/output files for 'decode.py' only ----
1921

20-
parser.add('-data_dir', '--data_dir', required=True,
21-
help='The directory with the preprocessed dataset')
22-
parser.add('--summary_dir', default=home_out('summaries_exp'),
23-
help='Directory for saving the summary data')
24-
parser.add('--chkpt_dir', default=home_out('chkpts_exp'),
25-
help='Directory for saving the model checkpoints')
26-
parser.add('--results_file', default=home_out('results.txt'),
27-
help='File for saving the results of the experiments')
22+
parser.add('-input_file', default=None,
23+
help="The encoded prediction file that will be decoded (only used in 'decode.py')")
24+
parser.add('-output_file', default=None,
25+
help="The output file where the decoded gesture will be stored (only used in 'decode.py')")
2826

2927
# ---- Flags ----
3028

31-
parser.add('-pretrain', '--pretrain_network', action='store_true',
29+
parser.add('-pretrain', '--pretrain_network', action='store_true',
3230
help='If set, pretrain the model in a layerwise manner')
33-
parser.add('-load_model', '--load_model_from_checkpoint', action='store_true',
31+
parser.add('-load_model', '--load_model_from_checkpoint', action='store_true',
3432
help='If set, load the model from a checkpoint')
35-
parser.add('-no_early_stopping', '--no_early_stopping', action='store_false',
33+
parser.add('-no_early_stopping', '--no_early_stopping', action='store_false',
3634
help='If set, disable early stopping')
3735

3836
# ---- Network architecture ---

config.yaml

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# See config.py for details about these parameters.
22

3-
# NOTE: The defaults checkpoint, result and dataset directories are set in the code.
3+
data_dir: "./dataset/processed/"
4+
summary_dir: "./results/summaries/"
5+
chkpt_dir: "./results/checkpoints/"
6+
results_file: "./results/results.txt"
47

58
seed: 123456
69

7-
#-------------------------------------------------------
8-
# These boolean flags can be enabled by supplying them |
9-
# through the command-line or uncommenting them below |
10-
#-------------------------------------------------------
11-
1210
delta_for_early_stopping: 0.5
1311

1412
# ---- Network architecture ----
@@ -31,9 +29,8 @@ lr: 0.0001
3129
pretraining_lr: 0.001
3230

3331

34-
#-----------------------------------------------------
35-
# Weight decay is disabled by default. |
36-
# You can enable it by setting its multiplier below: |
37-
#-----------------------------------------------------
32+
33+
# Weight decay is disabled by default.
34+
# You can enable it by setting its multiplier below:
3835

3936
# weight_decay: <some value>

motion_repr_learning/ae/decode.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,33 @@
22
This file contains a usage script, intended to test using interface.
33
Developed by Taras Kucherenko (tarask@kth.se)
44
"""
5+
import sys
6+
sys.path.append('.')
7+
import numpy as np
58

69
import train as tr
7-
import utils.data as dt
8-
import utils.flags as fl
910
from learn_ae_n_encode_dataset import create_nn, prepare_motion_data
10-
11-
import numpy as np
12-
13-
import sys
14-
15-
DATA_DIR = sys.argv[1]
16-
TEST_FILE = sys.argv[2]
17-
OUTPUT_FILE = sys.argv[3]
11+
from config import args
1812

1913
if __name__ == '__main__':
14+
# Make sure that the two mandatory arguments are provided
15+
if args.input_file is None or args.output_file is None:
16+
print("Usage: python decode.py -input_file INPUT_FILE -output_file OUTPUT_FILE \n" + \
17+
"Where INPUT_FILE is the encoded prediction file and OUTPUT_FILE is the file in which the decoded gestures will be saved.")
18+
exit(-1)
19+
20+
# For decoding these arguments are always False and True
21+
args.pretrain_network = False
22+
args.load_model_from_checkpoint = True
2023

2124
# Get the data
22-
Y_train_normalized, Y_train, Y_dev_normalized, max_val, mean_pose = prepare_motion_data(DATA_DIR)
25+
Y_train_normalized, Y_train, Y_dev_normalized, max_val, mean_pose = prepare_motion_data(args.data_dir)
2326

2427
# Train the network
25-
nn = create_nn(Y_train_normalized, Y_dev_normalized, max_val, mean_pose, restoring=True)
28+
nn = create_nn(Y_train_normalized, Y_dev_normalized, max_val, mean_pose)
2629

2730
# Read the encoding
28-
encoding = np.loadtxt(TEST_FILE)
31+
encoding = np.loadtxt(args.input_file)
2932

3033
print(encoding.shape)
3134

@@ -34,7 +37,7 @@
3437

3538
print(decoding.shape)
3639

37-
np.save(OUTPUT_FILE, decoding)
40+
np.save(args.output_file, decoding)
3841

3942
# Close Tf session
4043
nn.session.close()

motion_repr_learning/ae/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
import time
11+
from os.path import join, abspath
1112
import tensorflow as tf
1213
from tensorflow.python import debug as tf_debug
1314
import numpy as np
@@ -162,7 +163,7 @@ def learning(data, data_info, just_restore=False):
162163

163164
# Create a saver
164165
saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)
165-
chkpt_file = args.chkpt_dir + '/chkpt-final'
166+
chkpt_file = abspath(join(args.chkpt_dir, 'chkpt-final'))
166167

167168
# restore model, if needed
168169
if args.load_model_from_checkpoint:

0 commit comments

Comments
 (0)