Skip to content

Commit 7d44ad8

Browse files
committed
initial commit
0 parents  commit 7d44ad8

File tree

10 files changed

+418
-0
lines changed

10 files changed

+418
-0
lines changed

.gitignore

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Data
2+
data/hand
3+
data/gaze
4+
data/*.png
5+
samples
6+
outputs
7+
8+
# Log
9+
logs
10+
11+
# ETC
12+
paper.pdf
13+
14+
# Created by https://www.gitignore.io/api/python,vim
15+
16+
### Python ###
17+
# Byte-compiled / optimized / DLL files
18+
__pycache__/
19+
*.py[cod]
20+
*$py.class
21+
22+
# C extensions
23+
*.so
24+
25+
# Distribution / packaging
26+
.Python
27+
env/
28+
build/
29+
develop-eggs/
30+
dist/
31+
downloads/
32+
eggs/
33+
.eggs/
34+
lib/
35+
lib64/
36+
parts/
37+
sdist/
38+
var/
39+
wheels/
40+
*.egg-info/
41+
.installed.cfg
42+
*.egg
43+
44+
# PyInstaller
45+
# Usually these files are written by a python script from a template
46+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
47+
*.manifest
48+
*.spec
49+
50+
# Installer logs
51+
pip-log.txt
52+
pip-delete-this-directory.txt
53+
54+
# Unit test / coverage reports
55+
htmlcov/
56+
.tox/
57+
.coverage
58+
.coverage.*
59+
.cache
60+
nosetests.xml
61+
coverage.xml
62+
*,cover
63+
.hypothesis/
64+
65+
# Translations
66+
*.mo
67+
*.pot
68+
69+
# Django stuff:
70+
*.log
71+
local_settings.py
72+
73+
# Flask stuff:
74+
instance/
75+
.webassets-cache
76+
77+
# Scrapy stuff:
78+
.scrapy
79+
80+
# Sphinx documentation
81+
docs/_build/
82+
83+
# PyBuilder
84+
target/
85+
86+
# Jupyter Notebook
87+
.ipynb_checkpoints
88+
89+
# pyenv
90+
.python-version
91+
92+
# celery beat schedule file
93+
celerybeat-schedule
94+
95+
# dotenv
96+
.env
97+
98+
# virtualenv
99+
.venv/
100+
venv/
101+
ENV/
102+
103+
# Spyder project settings
104+
.spyderproject
105+
106+
# Rope project settings
107+
.ropeproject
108+
109+
110+
### Vim ###
111+
# swap
112+
[._]*.s[a-v][a-z]
113+
[._]*.sw[a-p]
114+
[._]s[a-v][a-z]
115+
[._]sw[a-p]
116+
# session
117+
Session.vim
118+
# temporary
119+
.netrwhist
120+
*~
121+
# auto-generated tag files
122+
tags
123+
124+
# End of https://www.gitignore.io/api/python,vim

README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Neural Combinatorial Optimization in Tensorflow
2+
3+
TensorFlow implementation of [Neural Combinatorial Optimization with Reinforcement Learning](http://arxiv.org/abs/1611.09940).
4+
5+
![model](./assets/model.png)
6+
7+
8+
## Requirements
9+
10+
- Python 2.7
11+
- [TensorFlow](https://www.tensorflow.org/) 0.12.0
12+
- [tqdm](https://github.com/tqdm/tqdm)
13+
14+
15+
## Usage
16+
17+
To train a model:
18+
19+
$ python main.py
20+
$ tensorboard --logdir=logs --host=0.0.0.0
21+
22+
To test a model:
23+
24+
$ python main.py --is_train=False
25+
26+
## Results
27+
28+
(in progress)
29+
30+
31+
## Author
32+
33+
Taehoon Kim / [@carpedm20](http://carpedm20.github.io)

assets/model.png

24.5 KB
Loading

config.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#-*- coding: utf-8 -*-
2+
import argparse
3+
4+
def str2bool(v):
5+
return v.lower() in ('true', '1')
6+
7+
arg_lists = []
8+
parser = argparse.ArgumentParser()
9+
10+
def add_argument_group(name):
11+
arg = parser.add_argument_group(name)
12+
arg_lists.append(arg)
13+
return arg
14+
15+
# Network
16+
net_arg = add_argument_group('Network')
17+
net_arg.add_argument('--hidden_dims', type=int, default=200, help='')
18+
net_arg.add_argument('--num_layer', type=int, default=2, help='')
19+
20+
# Data
21+
data_arg = add_argument_group('Data')
22+
data_arg.add_argument('--data_set', type=str, default='gaze')
23+
data_arg.add_argument('--data_dir', type=str, default='data')
24+
data_arg.add_argument('--input_height', type=int, default=35)
25+
data_arg.add_argument('--input_width', type=int, default=55)
26+
data_arg.add_argument('--task_name', type=str, default='TPS20')
27+
28+
# Training / test parameters
29+
train_arg = add_argument_group('Training')
30+
train_arg.add_argument('--is_train', type=str2bool, default=True, help='')
31+
train_arg.add_argument('--optimizer', type=str, default='rmsprop', help='')
32+
train_arg.add_argument('--max_step', type=int, default=10000, help='')
33+
train_arg.add_argument('--reg_scale', type=float, default=0.5, help='')
34+
train_arg.add_argument('--batch_size', type=int, default=512, help='')
35+
train_arg.add_argument('--learning_rate', type=float, default=0.001, help='')
36+
train_arg.add_argument('--checkpoint_secs', type=int, default=300, help='')
37+
train_arg.add_argument('--max_grad_norm', type=float, default=50, help='')
38+
39+
# Misc
40+
misc_arg = add_argument_group('Misc')
41+
misc_arg.add_argument('--log_step', type=int, default=20, help='')
42+
misc_arg.add_argument('--log_dir', type=str, default='logs')
43+
misc_arg.add_argument('--sample_dir', type=str, default='samples')
44+
misc_arg.add_argument('--output_dir', type=str, default='outputs')
45+
misc_arg.add_argument('--load_path', type=str, default='')
46+
misc_arg.add_argument('--debug', type=str2bool, default=False)
47+
misc_arg.add_argument('--gpu_memory_fraction', type=float, default=1.0)
48+
misc_arg.add_argument('--random_seed', type=int, default=123, help='')
49+
50+
def get_config():
51+
config, unparsed = parser.parse_known_args()
52+
return config, unparsed

data_loader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class DataLoader(object):
2+
def __init__(self):
3+
pass

layers.py

Whitespace-only changes.

main.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import sys
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
from trainer import Trainer
6+
from config import get_config
7+
from utils import prepare_dirs, save_config
8+
9+
config = None
10+
11+
def main(_):
12+
prepare_dirs(config)
13+
14+
rng = np.random.RandomState(config.random_seed)
15+
tf.set_random_seed(config.random_seed)
16+
17+
trainer = Trainer(config, rng)
18+
save_config(config.model_dir, config)
19+
20+
if config.is_train:
21+
trainer.train()
22+
else:
23+
if not config.load_path:
24+
raise Exception("[!] You should specify `load_path` to load a pretrained model")
25+
trainer.test()
26+
27+
if __name__ == "__main__":
28+
config, unparsed = get_config()
29+
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

model.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import tensorflow as tf
2+
from tensorflow.contrib.framework import arg_scope
3+
4+
from layers import *
5+
from utils import show_all_variables
6+
7+
class Model(object):
8+
def __init__(self, config, data_loader):
9+
self.data_loader = data_loader
10+
11+
self.task = config.task
12+
self.debug = config.debug
13+
self.config = config
14+
15+
self.input_height = config.input_height
16+
self.input_width = config.input_width
17+
self.input_channel = config.input_channel
18+
19+
self.reg_scale = config.reg_scale
20+
self.learning_rate = config.learning_rate
21+
self.max_grad_norm = config.max_grad_norm
22+
self.batch_size = config.batch_size
23+
24+
self.layer_dict = {}
25+
26+
self._build_placeholders()
27+
self._build_model()
28+
self._build_optim()
29+
30+
show_all_variables()
31+
32+
def _build_placeholders(self):
33+
self.inputs = tf.placeholder(tf.float32, name="inputs")
34+
self.lengths = tf.placeholder(tf.float32, name="lengths")
35+
self.targets = tf.placeholder(tf.float32, name="targets")
36+
self.is_train = tf.placeholder(tf.bool, name="is_train")
37+
38+
def _build_encoder(self):
39+
self.global_step = tf.Variable(0, trainable=False)
40+
41+
with tf.variable_scope("encoder"):
42+
self.cell = tf.nn.rnn_cell.LSTMCell(size)
43+
if num_layers > 1:
44+
self.cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
45+
46+
self.rnn = tf.nn.dynamic_rnn(self.cell, self.inputs, self.seq_length)
47+
48+
with tf.variable_scope("dencoder"):
49+
self.cell = tf.nn.rnn_cell.LSTMCell(size)
50+
if num_layers > 1:
51+
self.cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
52+
53+
self.rnn = tf.nn.dynamic_rnn(self.cell, self.inputs, self.seq_length)
54+
55+
def _build_optim(self):
56+
self.loss = tf.reduce_mean(self.output - self.targets)
57+
58+
self.learning_rate = tf.Variable(self.learning_rate)
59+
self.optim = tf.train.AdamOptimizer(self.learning_rate)

trainer.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
import numpy as np
3+
from tqdm import trange
4+
import tensorflow as tf
5+
from tensorflow.contrib.framework.python.ops import arg_scope
6+
7+
from model import Model
8+
from data_loader import DataLoader
9+
10+
class Trainer(object):
11+
def __init__(self, config, rng):
12+
self.config = config
13+
self.rng = rng
14+
15+
self.task = config.task
16+
self.model_dir = config.model_dir
17+
self.gpu_memory_fraction = config.gpu_memory_fraction
18+
19+
self.log_step = config.log_step
20+
self.max_step = config.max_step
21+
self.checkpoint_secs = config.checkpoint_secs
22+
23+
self.summary_ops = {}
24+
25+
self.model = Model(config, self.data_loader)
26+
self.data_loader = DataLoader(config, rng=self.rng)
27+
28+
self._build_session()
29+
30+
def _build_session(self):
31+
self.saver = tf.train.Saver()
32+
self.summary_writer = tf.summary.FileWriter(self.model_dir)
33+
34+
sv = tf.train.Supervisor(logdir=self.model_dir,
35+
is_chief=True,
36+
saver=self.saver,
37+
summary_op=None,
38+
summary_writer=self.summary_writer,
39+
save_summaries_secs=300,
40+
save_model_secs=self.checkpoint_secs,
41+
global_step=self.model.discrim_step)
42+
43+
gpu_options = tf.GPUOptions(
44+
per_process_gpu_memory_fraction=self.gpu_memory_fraction,
45+
allow_growth=True) # seems to be not working
46+
sess_config = tf.ConfigProto(allow_soft_placement=True,
47+
gpu_options=gpu_options)
48+
49+
self.sess = sv.prepare_or_wait_for_session(config=sess_config)
50+
51+
def train(self):
52+
print("[*] Training starts...")
53+
pass
54+
55+
def test(self):
56+
pass
57+
58+
def _inject_summary(self, tag, feed_dict, step):
59+
summaries = self.sess.run(self.summary_ops[tag], feed_dict)
60+
self.summary_writer.add_summary(summaries['summary'], step)
61+
62+
path = os.path.join(
63+
self.config.sample_model_dir, "{}.png".format(step))
64+
imwrite(path, img_tile(summaries['output'],
65+
tile_shape=self.config.sample_image_grid)[:,:,0])
66+
67+
def _get_summary_writer(self, result):
68+
if result['step'] % self.log_step == 0:
69+
return self.summary_writer
70+
else:
71+
return None

0 commit comments

Comments
 (0)