Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions decoding/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
#-----------------------------------------------------------------------------#
# Specify model and dictionary locations here
#-----------------------------------------------------------------------------#
path_to_model = '/u/rkiros/research/semhash/models/toydec.npz'
path_to_dictionary = '/ais/gobi3/u/rkiros/flickr8k/dictionary.pkl'
PATH_TO_MODEL = '/u/rkiros/research/semhash/models/toydec.npz'
PATH_TO_PARAMS = '/u/rkiros/research/semhash/models/toydec.npz'
PATH_TO_DICTIONARY = '/ais/gobi3/u/rkiros/flickr8k/dictionary.pkl'
#-----------------------------------------------------------------------------#

def load_model():
def load_model(
path_to_model=PATH_TO_MODEL, # model opts (.pkl)
path_to_params=PATH_TO_PARAMS, # model params (.npz)
path_to_dictionary=PATH_TO_DICTIONARY
):
"""
Load a trained model for decoding
"""
Expand All @@ -45,7 +50,7 @@ def load_model():
# Load parameters
print 'Loading model parameters...'
params = init_params(options)
params = load_params(path_to_model, params)
params = load_params(path_to_params, params)
tparams = init_tparams(params)

# Sampler.
Expand Down
21 changes: 14 additions & 7 deletions training/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,19 @@
#-----------------------------------------------------------------------------#
# Specify model and dictionary locations here
#-----------------------------------------------------------------------------#
path_to_model = '/u/rkiros/research/semhash/models/toy.npz'
path_to_dictionary = '/ais/gobi3/u/rkiros/bookgen/book_dictionary_large.pkl'
path_to_word2vec = '/ais/gobi3/u/rkiros/word2vec/GoogleNews-vectors-negative300.bin'
PATH_TO_MODEL = '/u/rkiros/research/semhash/models/toy.pkl'
PATH_TO_PARAMS = '/u/rkiros/research/semhash/models/toy.npz'
PATH_TO_DICTIONARY = '/ais/gobi3/u/rkiros/bookgen/book_dictionary_large.pkl'
PATH_TO_WORD2VEC = '/ais/gobi3/u/rkiros/word2vec/GoogleNews-vectors-negative300.bin'
#-----------------------------------------------------------------------------#

def load_model(embed_map=None):
def load_model(
embed_map=None,
path_to_model=PATH_TO_MODEL, # model opts (.pkl)
path_to_params=PATH_TO_PARAMS, # model params (.npz)
path_to_dictionary=PATH_TO_DICTIONARY,
path_to_word2vec=PATH_TO_WORD2VEC
):
"""
Load all model components + apply vocab expansion
"""
Expand All @@ -46,13 +53,13 @@ def load_model(embed_map=None):

# Load model options
print 'Loading model options...'
with open('%s.pkl'%path_to_model, 'rb') as f:
with open(path_to_model, 'rb') as f:
options = pkl.load(f)

# Load parameters
print 'Loading model parameters...'
params = init_params(options)
params = load_params(path_to_model, params)
params = load_params(path_to_params, params)
tparams = init_tparams(params)

# Extractor functions
Expand Down Expand Up @@ -149,7 +156,7 @@ def preprocess(text):
X.append(result)
return X

def load_googlenews_vectors():
def load_googlenews_vectors(path_to_word2vec=PATH_TO_WORD2VEC):
"""
load the word2vec GoogleNews vectors
"""
Expand Down
1 change: 1 addition & 0 deletions training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import theano
import theano.tensor as tensor
import numpy
import warnings

from collections import OrderedDict

Expand Down