create_binary now correctly subtokenizes predictions, and does not include unclear predictions
ben-baran committed Aug 18, 2018
1 parent 3167c55 commit b5898d8
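
Before this change, create_binary passed the raw variable name to ComboVocab as a single token; the new code subtokenizes the name first and skips any context group whose prediction target contains an out-of-vocabulary subtoken. A minimal sketch of that check (a hypothetical standalone helper, assuming utils.ComboVocab maps unknown subtokens to the '<unk>' sentinel, as the diff below relies on):

def target_is_unclear(variable_name, combo_vocab, subtokenize, max_subtokens=8):
    # Look up the first `max_subtokens` subtoken ids of the prediction target.
    sub_ids = combo_vocab.to_ids(subtokenize(variable_name))[:max_subtokens]
    # Round-trip the ids back to tokens so out-of-vocabulary entries surface
    # as the '<unk>' sentinel, mirroring the check added in create_binary.py.
    return any(token == '<unk>' for token in combo_vocab.to_tokens(sub_ids))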
46 changes: 28 additions & 18 deletions create_binary.py
@@ -4,7 +4,7 @@
 import json
 from tqdm import tqdm
 import re
-from utils import ComboVocab, to_subtokenized_list
+from utils import ComboVocab, to_subtokenized_list, subtokenize
 
 
 n_ctx = 64 # size of context per side. This is decided earlier in the processing pipeline.
@@ -18,12 +18,15 @@
 train_save_folder = 'data/train_tmp/'
 val_save_folder = 'data/val_tmp/'
 test_save_folder = 'data/test_tmp/'
-max_uses = 20000
-proportion_validation = 0.1
-proportion_test = 0.1
 
 train_context_files, val_context_files, test_context_files = {}, {}, {}
 random.shuffle(file_seeks)
+n_unclear_skips = 0
+
+
+max_uses = 5000
+proportion_validation = 0.1
+proportion_test = 0.1
 
 bar = tqdm(total = min(max_uses, len(file_seeks)), desc = 'processing', unit = 'ctx groups')
 for seek_i, seek in enumerate(file_seeks):
@@ -51,22 +54,27 @@
         context_files[n_contexts] = open(save_folder + '%d.bin' % n_contexts, 'wb')
     fout = context_files[n_contexts]
 
-    variable_subtokens = combo_vocab.to_ids([data['variableName']])[:8] # maxes out number of subtokens to 8
-    variable_subtokens.extend(combo_vocab.to_ids(['<<PAD>>' for i in range(8 - len(variable_subtokens))]))
-    fout.write(struct.pack('<8I', *variable_subtokens))
-
-    for context in data['usage']:
-        context_a = to_subtokenized_list(context[:64])[-64:]
-        context_b = to_subtokenized_list(context[129:64:-1])[-64:]
-        fout.write(struct.pack('<64I', *combo_vocab.to_ids(context_a)))
-        fout.write(struct.pack('<64I', *combo_vocab.to_ids(context_b)))
-        # print("-" * 100)
-        # print(context)
-        # for o_token, t_token in zip(context_a, combo_vocab.to_tokens(combo_vocab.to_ids(context_a))):
-        #     print(o_token, '-->', t_token)
-        # break
-    bar.update(1)
+    var_subtokens = subtokenize(data['variableName'])
+    var_subids = combo_vocab.to_ids(var_subtokens)[:8] # maxes out number of subtokens to 8
+
+    unclear_subtoken = False
+    for subtoken in combo_vocab.to_tokens(var_subids): # turn it back into a token to check, so that we can skip over unclear subtokens
+        if subtoken == '<unk>':
+            unclear_subtoken = True
+            n_unclear_skips += 1
+            break
+    if not unclear_subtoken:
+        var_subids.extend(combo_vocab.to_ids(['<<PAD>>' for i in range(8 - len(var_subids))]))
+        fout.write(struct.pack('<8I', *var_subids))
+
+        for context in data['usage']:
+            context_a = to_subtokenized_list(context[:64])[-64:]
+            context_b = to_subtokenized_list(context[129:64:-1])[-64:]
+            fout.write(struct.pack('<64I', *combo_vocab.to_ids(context_a)))
+            fout.write(struct.pack('<64I', *combo_vocab.to_ids(context_b)))
+
+    bar.update(1)
     if seek_i == max_uses - 1:
         break

@@ -76,4 +84,6 @@
 for fout in context_files.values():
     fout.close()
 
+print('Number of context groups skipped due to uncommon prediction:', n_unclear_skips)
+
 # idea: use first, second, third moments of the distributions of weights?
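
Each group the script keeps is written as a fixed-size record into data/*_tmp/<n_contexts>.bin: 8 little-endian uint32 target subtoken ids (padded with '<<PAD>>'), followed by 64 ids of context before the variable and 64 ids of context after it (stored reversed) for each usage. A rough sketch of a reader under those assumptions (the helper and its names are illustrative, not part of the repository):

import struct

def read_record(f, n_contexts):
    # 8 uint32 subtoken ids for the prediction target (padded with '<<PAD>>').
    target = struct.unpack('<8I', f.read(8 * 4))
    usages = []
    for _ in range(n_contexts):
        # 64 ids of context before the variable, then 64 after (reversed on disk).
        left = struct.unpack('<64I', f.read(64 * 4))
        right = struct.unpack('<64I', f.read(64 * 4))
        usages.append((left, right))
    return target, usages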
