
Commit 6e10d4e

update glove
1 parent 7b8c194

2 files changed: +38 −27 lines


nlp_class2/glove.py (+27 −10)
@@ -10,7 +10,7 @@
 
 from datetime import datetime
 from sklearn.utils import shuffle
-from word2vec import get_wikipedia_data, find_analogies
+from word2vec import get_wikipedia_data, find_analogies, get_sentences_with_word2idx_limit_vocab
 
 # Experiments
 # previous results did not make sense b/c X was built incorrectly
@@ -260,8 +260,11 @@ def save(self, fn):
         np.savez(fn, *arrays)
 
 
-def main(we_file, w2i_file, n_files=50):
-    cc_matrix = "cc_matrix_%s.npy" % n_files
+def main(we_file, w2i_file, use_brown=True, n_files=50):
+    if use_brown:
+        cc_matrix = "cc_matrix_brown.npy"
+    else:
+        cc_matrix = "cc_matrix_%s.npy" % n_files
 
     # hacky way of checking if we need to re-load the raw data or not
     # remember, only the co-occurrence matrix is needed for training
@@ -270,7 +273,19 @@ def main(we_file, w2i_file, n_files=50):
             word2idx = json.load(f)
         sentences = [] # dummy - we won't actually use it
     else:
-        sentences, word2idx = get_wikipedia_data(n_files=n_files, n_vocab=2000)
+        if use_brown:
+            keep_words = set([
+                'king', 'man', 'woman',
+                'france', 'paris', 'london', 'rome', 'italy', 'britain', 'england',
+                'french', 'english', 'japan', 'japanese', 'chinese', 'italian',
+                'australia', 'australian', 'december', 'november', 'june',
+                'january', 'february', 'march', 'april', 'may', 'july', 'august',
+                'september', 'october',
+            ])
+            sentences, word2idx = get_sentences_with_word2idx_limit_vocab(keep_words=keep_words)
+        else:
+            sentences, word2idx = get_wikipedia_data(n_files=n_files, n_vocab=2000)
+
         with open(w2i_file, 'w') as f:
             json.dump(word2idx, f)
 
@@ -282,17 +297,19 @@ def main(we_file, w2i_file, n_files=50):
         cc_matrix=cc_matrix,
         learning_rate=3*10e-5,
         reg=0.01,
-        epochs=2000,
-        gd=True,
-        use_theano=True
+        epochs=10,
+        gd=False,
+        use_theano=False
     ) # gradient descent
     model.save(we_file)
 
 
 if __name__ == '__main__':
-    we = 'glove_model_50.npz'
-    w2i = 'glove_word2idx_50.json'
-    main(we, w2i)
+    # we = 'glove_model_50.npz'
+    # w2i = 'glove_word2idx_50.json'
+    we = 'glove_model_brown.npz'
+    w2i = 'glove_word2idx_brown.json'
+    main(we, w2i, use_brown=True)
     for concat in (True, False):
         print "** concat:", concat
         find_analogies('king', 'man', 'woman', concat, we, w2i)
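
For context, the corpus selection and caching behaviour this commit threads through main() can be summarized as a standalone sketch. Note this is only an illustration: load_corpus is a hypothetical helper written here for clarity, while the actual commit inlines the same logic in main(); the imports mirror the updated import line above.

# Hypothetical helper illustrating the corpus/cache selection added to main().
# Only word2idx is persisted here; the co-occurrence matrix itself is cached
# separately as cc_matrix_brown.npy or cc_matrix_<n_files>.npy.
import os
import json
from word2vec import get_wikipedia_data, get_sentences_with_word2idx_limit_vocab

def load_corpus(w2i_file, keep_words, use_brown=True, n_files=50):
    if os.path.exists(w2i_file):
        # word2idx was already saved: skip re-tokenizing; sentences are not
        # needed because training only reads the cached co-occurrence matrix
        with open(w2i_file) as f:
            word2idx = json.load(f)
        sentences = []
    else:
        if use_brown:
            sentences, word2idx = get_sentences_with_word2idx_limit_vocab(keep_words=keep_words)
        else:
            sentences, word2idx = get_wikipedia_data(n_files=n_files, n_vocab=2000)
        with open(w2i_file, 'w') as f:
            json.dump(word2idx, f)
    return sentences, word2idx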

rnn_class/brown.py (+11 −17)
@@ -1,6 +1,12 @@
 from nltk.corpus import brown
 import operator
 
+KEEP_WORDS = set([
+    'king', 'man', 'queen', 'woman',
+    'italy', 'rome', 'france', 'paris',
+    'london', 'britain', 'england',
+])
+
 
 def get_sentences():
     # returns 57340 of the Brown corpus
@@ -29,7 +35,7 @@ def get_sentences_with_word2idx():
     return indexed_sentences, word2idx
 
 
-def get_sentences_with_word2idx_limit_vocab(n_vocab=2000):
+def get_sentences_with_word2idx_limit_vocab(n_vocab=2000, keep_words=KEEP_WORDS):
     sentences = get_sentences()
     indexed_sentences = []
 
@@ -65,14 +71,8 @@ def get_sentences_with_word2idx_limit_vocab(n_vocab=2000):
     # set all the words I want to keep to infinity
     # so that they are included when I pick the most
     # common words
-    word_idx_count[word2idx['king']] = float('inf')
-    word_idx_count[word2idx['queen']] = float('inf')
-    word_idx_count[word2idx['man']] = float('inf')
-    word_idx_count[word2idx['woman']] = float('inf')
-    word_idx_count[word2idx['italy']] = float('inf')
-    word_idx_count[word2idx['rome']] = float('inf')
-    word_idx_count[word2idx['france']] = float('inf')
-    word_idx_count[word2idx['paris']] = float('inf')
+    for word in keep_words:
+        word_idx_count[word2idx[word]] = float('inf')
 
     sorted_word_idx_count = sorted(word_idx_count.items(), key=operator.itemgetter(1), reverse=True)
     word2idx_small = {}
9090

9191
assert('START' in word2idx_small)
9292
assert('END' in word2idx_small)
93-
assert('king' in word2idx_small)
94-
assert('queen' in word2idx_small)
95-
assert('man' in word2idx_small)
96-
assert('woman' in word2idx_small)
97-
assert('italy' in word2idx_small)
98-
assert('rome' in word2idx_small)
99-
assert('france' in word2idx_small)
100-
assert('paris' in word2idx_small)
93+
for word in keep_words:
94+
assert(word in word2idx_small)
10195

10296
# map old idx to new idx
10397
sentences_small = []
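
The refactored keep_words mechanism relies on the trick described in the comments above: counts for the keep words are set to infinity so they always survive the top-n_vocab cut. A toy, standalone illustration with made-up counts follows; only the infinity-count sorting trick mirrors brown.py.

# Toy demonstration of forcing keep_words into a size-limited vocabulary.
# The counts below are invented; the real function builds word_idx_count
# from the Brown corpus.
import operator

word2idx = {'the': 0, 'king': 1, 'zebra': 2, 'of': 3}
word_idx_count = {0: 100, 1: 2, 2: 5, 3: 80}
keep_words = set(['king'])
n_vocab = 3

for word in keep_words:
    word_idx_count[word2idx[word]] = float('inf')

sorted_counts = sorted(word_idx_count.items(), key=operator.itemgetter(1), reverse=True)
top_idx = [idx for idx, count in sorted_counts[:n_vocab]]
print(top_idx)  # [1, 0, 3] -- 'king' survives despite its low raw count; 'zebra' is cut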
