Skip to content

Commit

Permalink
Fixed sign on actor loss, minor
Browse files Browse the repository at this point in the history
refactoring in model

still looking for hyperparams that work!!
  • Loading branch information
pemami4911 committed Oct 22, 2017
1 parent ef76957 commit 55ee1e9
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 95 deletions.
26 changes: 26 additions & 0 deletions hyperparam_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import subprocess
import numpy as np
import sys
from math import floor, log10


if __name__ == '__main__':
exp_i = sys.argv[1]
#rand_seed = int(sys.argv[2])

#np.random.seed(rand_seed)


exps = [4]
num = np.arange(1, 9)

#num_trials = 25

seeds = [123, 343]

for rs in seeds:
for exp in exps:
for n in num:

lr = n * (1./(10 ** exp))
subprocess.call(["./tune_hyper.sh", str(lr), str(rs), exp_i])
22 changes: 10 additions & 12 deletions main.sh
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
#!/bin/bash

TASK='tsp_20'
DROPOUT=0.0
BEAM_SIZE=1
EMBEDDING_DIM=128
HIDDEN_DIM=128
BATCH_SIZE=128
ACTOR_NET_LR=1e-5
CRITIC_NET_LR=1e-4
ACTOR_NET_LR=1e-3
CRITIC_NET_LR=1e-3
ACTOR_LR_DECAY_RATE=0.96
ACTOR_LR_DECAY_STEP=5000
CRITIC_LR_DECAY_RATE=0.96
CRITIC_LR_DECAY_STEP=5000
N_PROCESS_BLOCKS=3
N_GLIMPSES=1
N_EPOCHS=100
N_EPOCHS=500
EPOCH_START=0
MAX_GRAD_NORM=2.0
RANDOM_SEED=$1
RUN_NAME="tsp_20-seed-$RANDOM_SEED"
TRAIN_SIZE=500000
VAL_SIZE=1500
LOAD_PATH="outputs/tsp_20/tsp_20-seed-320-entropy-5e4/epoch-3.pt"
MAX_GRAD_NORM=1.0
RANDOM_SEED=1000
RUN_NAME="$1-$ACTOR_NET_LR-seed-$RANDOM_SEED"
TRAIN_SIZE=1280000
VAL_SIZE=1000
LOAD_PATH="outputs/tsp_20/LR3-$ACTOR_NET_LR-seed-$RANDOM_SEED/epoch-5.pt"
USE_CUDA=True
DISABLE_TENSORBOARD=False
ENTROPY_COEFF=0.00
REWARD_SCALE=1
USE_TANH=True

./trainer.py --task $TASK --dropout $DROPOUT --beam_size $BEAM_SIZE --actor_net_lr $ACTOR_NET_LR --critic_net_lr $CRITIC_NET_LR --n_epochs $N_EPOCHS --random_seed $RANDOM_SEED --max_grad_norm $MAX_GRAD_NORM --run_name $RUN_NAME --epoch_start $EPOCH_START --train_size $TRAIN_SIZE --n_process_blocks $N_PROCESS_BLOCKS --batch_size $BATCH_SIZE --actor_lr_decay_rate $ACTOR_LR_DECAY_RATE --actor_lr_decay_step $ACTOR_LR_DECAY_STEP --critic_lr_decay_rate $CRITIC_LR_DECAY_RATE --critic_lr_decay_step $CRITIC_LR_DECAY_STEP --embedding_dim $EMBEDDING_DIM --hidden_dim $HIDDEN_DIM --val_size $VAL_SIZE --n_glimpses $N_GLIMPSES --use_cuda $USE_CUDA --disable_tensorboard $DISABLE_TENSORBOARD --entropy_coeff $ENTROPY_COEFF --reward_scale $REWARD_SCALE --use_tanh $USE_TANH
./trainer.py --task $TASK --beam_size $BEAM_SIZE --actor_net_lr $ACTOR_NET_LR --critic_net_lr $CRITIC_NET_LR --n_epochs $N_EPOCHS --random_seed $RANDOM_SEED --max_grad_norm $MAX_GRAD_NORM --run_name $RUN_NAME --epoch_start $EPOCH_START --train_size $TRAIN_SIZE --n_process_blocks $N_PROCESS_BLOCKS --batch_size $BATCH_SIZE --actor_lr_decay_rate $ACTOR_LR_DECAY_RATE --actor_lr_decay_step $ACTOR_LR_DECAY_STEP --critic_lr_decay_rate $CRITIC_LR_DECAY_RATE --critic_lr_decay_step $CRITIC_LR_DECAY_STEP --embedding_dim $EMBEDDING_DIM --hidden_dim $HIDDEN_DIM --val_size $VAL_SIZE --n_glimpses $N_GLIMPSES --use_cuda $USE_CUDA --disable_tensorboard $DISABLE_TENSORBOARD --reward_scale $REWARD_SCALE --use_tanh $USE_TANH

121 changes: 56 additions & 65 deletions neural_combinatorial_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,36 @@
class Encoder(nn.Module):
"""Maps a graph represented as an input sequence
to a hidden vector"""
def __init__(self, input_dim, hidden_dim, n_layers, dropout, use_cuda):
def __init__(self, input_dim, hidden_dim, use_cuda):
super(Encoder, self).__init__()
self.hidden_dim = hidden_dim
self.n_layers = n_layers
self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers,
dropout=dropout)
self.lstm = nn.LSTM(input_dim, hidden_dim)
self.use_cuda = use_cuda

self.enc_init_state = self.init_hidden(hidden_dim)

def forward(self, x, hidden):
output, hidden = self.lstm(x, hidden)
return output, hidden

def init_hidden(self, inputs):
batch_size = inputs.size(1)
hx = Variable(torch.zeros(self.n_layers,
batch_size,
self.hidden_dim),
requires_grad=False)
cx = Variable(torch.zeros(self.n_layers,
batch_size,
self.hidden_dim),
requires_grad=False)
def init_hidden(self, hidden_dim):
"""Trainable initial hidden state"""
enc_init_hx = torch.FloatTensor(hidden_dim)
if self.use_cuda:
return hx.cuda(), cx.cuda()
else:
return hx, cx
enc_init_hx = enc_init_hx.cuda()

enc_init_hx = nn.Parameter(enc_init_hx)
enc_init_hx.data.uniform_(-(1. / math.sqrt(hidden_dim)),
1. / math.sqrt(hidden_dim))

enc_init_cx = torch.FloatTensor(hidden_dim)
if self.use_cuda:
enc_init_cx = enc_init_cx.cuda()

enc_init_cx = nn.Parameter(enc_init_cx)
enc_init_cx.data.uniform_(-(1. / math.sqrt(hidden_dim)),
1. / math.sqrt(hidden_dim))
return (enc_init_hx, enc_init_cx)


class Attention(nn.Module):
"""A generic attention module for a decoder in seq2seq"""
Expand Down Expand Up @@ -91,7 +95,6 @@ def __init__(self,
terminating_symbol,
use_tanh,
decode_type,
dropout,
n_glimpses=1,
beam_size=0,
use_cuda=True):
Expand All @@ -106,19 +109,13 @@ def __init__(self,
self.beam_size = beam_size
self.use_cuda = use_cuda

self.dropout = nn.Dropout(p=dropout)
self.input_weights = nn.Linear(embedding_dim, 4 * hidden_dim)
self.hidden_weights = nn.Linear(hidden_dim, 4 * hidden_dim)
self.linear_hidden_out = nn.Linear(hidden_dim * 2, hidden_dim)

self.pointer = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration, use_cuda=self.use_cuda)
self.glimpse = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration, use_cuda=self.use_cuda)
self.glimpse = Attention(hidden_dim, use_tanh=False, use_cuda=self.use_cuda)
self.sm = nn.Softmax()

# self.neg_inf = Variable(torch.Tensor([float('-inf')]))
# if self.use_cuda:
# self.neg_inf = self.neg_inf.cuda()

def apply_mask_to_logits(self, step, logits, mask, prev_idxs):
if mask is None:
mask = torch.zeros(logits.size()).byte()
Expand All @@ -128,8 +125,6 @@ def apply_mask_to_logits(self, step, logits, mask, prev_idxs):
# to prevent them from being reselected.
# Or, allow re-selection and penalize in the objective function
if prev_idxs is not None:
#mask_size = logits.size(0) * step
#n_inf = self.neg_inf.repeat(mask_size)
# set most recently selected idx values to 1
mask[[x for x in range(logits.size(0))],
prev_idxs.data] = 1
Expand All @@ -141,15 +136,13 @@ def forward(self, decoder_input, embedded_inputs, hidden, context):
Args:
decoder_input: The initial input to the decoder
size is [batch_size x embedding_dim]. Trainable parameter.
embedded_inputs: [sourceL x batch_size x embeddign_dim]
embedded_inputs: [sourceL x batch_size x embedding_dim]
hidden: the prev hidden state, size is [batch_size x hidden_dim].
Initially this is set to (enc_h[-1], enc_c[-1])
context: encoder outputs, [sourceL x batch_size x hidden_dim]
"""
def recurrence(x, hidden, logit_mask, prev_idxs, step):

x = self.dropout(x)

hx, cx = hidden # batch_size x hidden_dim

gates = self.input_weights(x) + self.hidden_weights(hx)
Expand All @@ -171,7 +164,7 @@ def recurrence(x, hidden, logit_mask, prev_idxs, step):
# [batch_size x h_dim x 1]
g_l = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2)
_, logits = self.pointer(g_l, context)
#h_tilde = F.tanh(self.linear_hidden_out(torch.cat((h_tilde, hy), 1)))

logits, logit_mask = self.apply_mask_to_logits(step, logits, logit_mask, prev_idxs)
probs = self.sm(logits)
return hy, cy, probs, logit_mask
Expand All @@ -186,7 +179,6 @@ def recurrence(x, hidden, logit_mask, prev_idxs, step):

if self.decode_type == "stochastic":
for i in steps:

hx, cx, probs, mask = recurrence(decoder_input, hidden, mask, idxs, i)
hidden = (hx, cx)
# select the next inputs for the decoder [batch_size x hidden_dim]
Expand Down Expand Up @@ -316,20 +308,21 @@ class PointerNetwork(nn.Module):
"""The pointer network, which is the core seq2seq
model"""
def __init__(self,
encoder,
embedding_dim,
hidden_dim,
max_decoding_len,
terminating_symbol,
n_glimpses,
tanh_exploration,
use_tanh,
dropout,
beam_size,
use_cuda):
super(PointerNetwork, self).__init__()

self.encoder = encoder
self.encoder = Encoder(
embedding_dim,
hidden_dim,
use_cuda)

self.decoder = Decoder(
embedding_dim,
Expand All @@ -339,11 +332,11 @@ def __init__(self,
use_tanh=use_tanh,
terminating_symbol=terminating_symbol,
decode_type="stochastic",
dropout=dropout,
n_glimpses=n_glimpses,
beam_size=beam_size,
use_cuda=use_cuda)


# Trainable initial hidden states
dec_in_0 = torch.FloatTensor(embedding_dim)
if use_cuda:
dec_in_0 = dec_in_0.cuda()
Expand All @@ -358,11 +351,13 @@ def forward(self, inputs):
inputs: [sourceL x batch_size x embedding_dim]
"""

encoder_h0, encoder_c0 = self.encoder.init_hidden(inputs)

# encoder forward pass
enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_h0, encoder_c0))
(encoder_hx, encoder_cx) = self.encoder.enc_init_state
encoder_hx = encoder_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)
encoder_cx = encoder_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)

# encoder forward pass
enc_h, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx))

dec_init_state = (enc_h_t[-1], enc_c_t[-1])

# repeat decoder_in_0 across batch
Expand All @@ -371,15 +366,15 @@ def forward(self, inputs):
(pointer_probs, input_idxs), dec_hidden_t = self.decoder(decoder_input,
inputs,
dec_init_state,
enc_outputs)
enc_h)

return pointer_probs, input_idxs


class CriticNetwork(nn.Module):
"""Useful as a baseline in REINFORCE updates"""
def __init__(self,
encoder,
embedding_dim,
hidden_dim,
n_process_block_iters,
tanh_exploration,
Expand All @@ -390,8 +385,13 @@ def __init__(self,
self.hidden_dim = hidden_dim
self.n_process_block_iters = n_process_block_iters

self.encoder = encoder
self.process_block = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration, use_cuda=use_cuda)
self.encoder = Encoder(
embedding_dim,
hidden_dim,
use_cuda)

self.process_block = Attention(hidden_dim,
use_tanh=use_tanh, C=tanh_exploration, use_cuda=use_cuda)
self.sm = nn.Softmax()
self.decoder = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
Expand All @@ -404,11 +404,14 @@ def forward(self, inputs):
Args:
inputs: [embedding_dim x batch_size x sourceL] of embedded inputs
"""
encoder_h0, encoder_c0 = self.encoder.init_hidden(inputs)

(encoder_hx, encoder_cx) = self.encoder.enc_init_state
encoder_hx = encoder_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)
encoder_cx = encoder_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)

# encoder forward pass
enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_h0, encoder_c0))

enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx))
# grab the hidden state and process it via the process block
process_block_state = enc_h_t[-1]
for i in range(self.n_process_block_iters):
Expand All @@ -432,10 +435,8 @@ def __init__(self,
terminating_symbol,
n_glimpses,
n_process_block_iters,
n_layers,
tanh_exploration,
use_tanh,
dropout,
beam_size,
objective_fn,
is_train,
Expand All @@ -446,41 +447,31 @@ def __init__(self,
self.is_train = is_train
self.use_cuda = use_cuda

self.encoder = Encoder(
embedding_dim,
hidden_dim,
n_layers,
dropout,
use_cuda)

self.actor_net = PointerNetwork(
self.encoder,
embedding_dim,
hidden_dim,
max_decoding_len,
terminating_symbol,
n_glimpses,
tanh_exploration,
use_tanh,
dropout,
beam_size,
use_cuda)

self.critic_net = CriticNetwork(
self.encoder,
embedding_dim,
hidden_dim,
n_process_block_iters,
tanh_exploration,
use_tanh,
False,
use_cuda)

embedding_ = torch.FloatTensor(input_dim,
embedding_dim)

if self.use_cuda:
embedding_ = embedding_.cuda()

self.embedding = nn.Parameter(embedding_)

self.embedding = nn.Parameter(embedding_)
self.embedding.data.uniform_(-(1. / math.sqrt(embedding_dim)),
1. / math.sqrt(embedding_dim))

Expand Down
Loading

0 comments on commit 55ee1e9

Please sign in to comment.