Skip to content

Commit

Permalink
zero pad for inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
carpedm20 committed Jan 24, 2016
1 parent 6d1ab6e commit b5261b3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
3 changes: 1 addition & 2 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def create_vocabulary(vocabulary_path, context, max_vocabulary_size,
vocab[w] += 1
vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
if len(vocab_list) > max_vocabulary_size:
import ipdb; ipdb.set_trace()
vocab_list = vocab_list[:max_vocabulary_size]
keys = [int(key[len(_ENTITY):]) for key in vocab.keys() if _ENTITY in key]
for key in set(range(max(keys))) - set(keys):
Expand Down Expand Up @@ -234,7 +233,7 @@ def load_vocab(data_dir, dataset_name, vocab_size):

def load_dataset(data_dir, dataset_name, vocab_size):
train_files = os.path.join(data_dir, dataset_name, "questions",
"training", "*.question.ids%s" % (vocab_size))
"training", "*.question.ids%s_*" % (vocab_size))
for fname in glob(train_files):
with open(fname) as f:
yield f.read().split("\n\n")
Expand Down
4 changes: 2 additions & 2 deletions model/deep_lstm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
from tensorflow.models.rnn import rnn, rnn_cell

from utils import pad_array
from base_model import Model
from cells import LSTMCell, MultiRNNCellWithSkipConn, DropoutWrapper
from data_utils import load_vocab, load_dataset
Expand Down Expand Up @@ -107,9 +108,8 @@ def train(self, sess, epoch=25, learning_rate=0.0002, momentum=0.9,
answers.append(answers)

import ipdb; ipdb.set_trace()
inputs = pad_array(contexts, self.max_nsteps)
cost = sess.run([loss], feed_dict={})

#self.model.

def test(self, voab_size):
self.prepare_model(data_dir, dataset_name, vocab_size)
5 changes: 5 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import pprint
import numpy as np

pp = pprint.PrettyPrinter()

def pad_array(array, width):
map(lambda x: x.extend([0]*(width-len(x))), array)
return np.array(array)

0 comments on commit b5261b3

Please sign in to comment.