Began work on data loader.
Fixed bug with reversing back half of context.
Added <<end_id>> subtokens so that the boundaries between subtokens of adjacent tokens are clear.
ben-baran committed Aug 18, 2018
1 parent b5898d8 commit 8bb267f
Showing 3 changed files with 53 additions and 7 deletions.
create_binary.py: 3 changes (1 addition, 2 deletions)
@@ -6,7 +6,6 @@
import re
from utils import ComboVocab, to_subtokenized_list, subtokenize


n_ctx = 64 # size of context per side. This is decided earlier in the processing pipeline.
combo_vocab = ComboVocab()

@@ -70,7 +69,7 @@

for context in data['usage']:
    context_a = to_subtokenized_list(context[:64])[-64:]
    context_b = to_subtokenized_list(context[129:64:-1])[-64:]
    context_b = to_subtokenized_list(context[129:64])[:65:-1]
    fout.write(struct.pack('<64I', *combo_vocab.to_ids(context_a)))
    fout.write(struct.pack('<64I', *combo_vocab.to_ids(context_b)))

notes.md: 8 changes (7 additions, 1 deletion)
@@ -1,3 +1,9 @@
## General
* Tokens are split into subtokens so that the model is agnostic to whether you use `camelCase` or `c_style` naming conventions; the representation is also case-invariant after splitting. This lets you unify tokens written in either style, but it does lose _some_ information.
* When loading contexts, we sample files with probability proportional to file size, which is proportional to the number of contexts in the file; this oversamples cases with a high number of contexts. I'm still not sure whether this is preferable.
* `<<start_id>>` and `<<end_id>>` are used to guide the seq2seq model, but we also need `<<end_id>>` to distinguish `ClassName varName` from `classNameVarName`; see the sketch after this list.
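A rough sketch of both points, using a made-up `split_subtokens` helper (the real splitting lives in `utils.subtokenize`; the rule below is only an assumption for illustration):

```python
import re

def split_subtokens(token):
    # split on underscores and lowercase-to-uppercase boundaries, then lowercase
    parts = re.split(r'_|(?<=[a-z0-9])(?=[A-Z])', token)
    return [p.lower() for p in parts if p]

print(split_subtokens('camelCase'))  # ['camel', 'case']
print(split_subtokens('c_style'))    # ['c', 'style']

# Without a separator, `ClassName varName` and `classNameVarName` both flatten to
# ['class', 'name', 'var', 'name'].  Appending <<end_id>> after each identifier's
# subtokens keeps them distinguishable:
#   ClassName varName  -> ['class', 'name', '<<end_id>>', 'var', 'name', '<<end_id>>']
#   classNameVarName   -> ['class', 'name', 'var', 'name', '<<end_id>>']
```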

## On Processing
* For the future, make sure lone underscores appear as their own characters.
* Also, the current scheme makes something like `_0` appear to be an underscore followed by an integer literal.
* The current scheme makes something like `_0` appear to be an underscore followed by an integer literal; see the sketch after this list.
* We don't predict `<unk>` tokens, since they make up only around 3-4% of the dataset and predicting them would not make the network output any useful information.
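A tiny illustration of the `_0` issue; the splitting rule here is an assumption, not the repo's `subtokenize`:

```python
import re

# Hypothetical split of the identifier `_0`: it comes apart into '_' and '0',
# and the lone '0' then looks like the integer literal 0 rather than part of a name.
pieces = re.findall(r'_|[0-9]+|[A-Za-z]+', '_0')
print(pieces)  # ['_', '0']
```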
utils.py: 49 changes (45 additions, 4 deletions)
@@ -1,6 +1,10 @@
import re
import pickle
import mxnet.contrib.text as mtext
import glob
import os
import numpy as np
import struct

kw_or_builtin = set(('abstract continue for new switch assert default goto package synchronized '
'boolean do if private this break double implements protected throw byte else '
@@ -31,12 +35,13 @@ def to_subtokenized_list(l):
                stized.append(item)
            else:
                stized.extend(subtokenize(item))
                stized.append('<<end_id>>')
        else:
            stized.append(item)
    return stized

class ComboVocab:
    literal_constants = ['char', 'string', 'float', 'double', 'hex_int', 'bin_int', 'int']
    literal_constants = ['char', 'string', 'float', 'double', 'hex_int', 'bin_int', 'int', 'start_id', 'end_id']
    const_to_literal = {s:i for i, s in enumerate(literal_constants)}

    def __init__(self, counters_fname = 'data/counters.pkl', min_for_vocabulary = 4096):
@@ -56,7 +61,10 @@ def to_ids(self, l):
            if index == 0:
                new_index = self.other_vocab.to_indices(l[i])
                if new_index == 0: # tests for literals
                    if l[i][0] == "'":
                    removed_brackets = l[i][2:-2]
                    if removed_brackets in ComboVocab.const_to_literal:
                        new_index = self.n_others + ComboVocab.const_to_literal[removed_brackets]
                    elif l[i][0] == "'":
                        new_index = self.n_others + ComboVocab.const_to_literal['char']
                    elif l[i][0] == '"':
                        new_index = self.n_others + ComboVocab.const_to_literal['string']
@@ -82,13 +90,46 @@ def to_ids(self, l):
                indices[i] -= 1
        return indices

    def to_tokens(self, l): # non-reversible, i.e. to_ids(to_tokens(l)) will probably not work
    def to_tokens(self, l):
        translation = [None for x in l]
        for i, x in enumerate(l):
            if x >= self.n_subtoks + self.n_others:
                translation[i] = '<<' + ComboVocab.literal_constants[x - self.n_subtoks - self.n_others] + '_literal>>'
                translation[i] = '<<' + ComboVocab.literal_constants[x - self.n_subtoks - self.n_others] + '>>'
            elif x >= self.n_subtoks:
                translation[i] = self.other_vocab.to_tokens(x - self.n_subtoks)
            else:
                translation[i] = self.subtok_vocab.to_tokens(x + 1)
        return translation

class ContextLoader:
    def __init__(self, folder_name = 'data/train_tmp/', batch_size = 32):
        self.batch_size = batch_size
        self.context_files = []
        self.context_props = []
        for filename in os.listdir(folder_name):
            if filename[-4:] != '.bin':
                continue
            n_contexts = int(filename.split('.')[0])
            full_path = folder_name + filename

            # we get the size for proportionally sampling the files
            size = os.path.getsize(full_path)
            self.context_props.append(size)
            self.context_files.append((n_contexts, open(full_path, 'rb')))
        total_size = sum(self.context_props)
        self.context_props = np.array([size / total_size for size in self.context_props])

    def get_batch(self):
        # intended to return a random batch of [pre_sequence, post_sequence, input_vars, output_vars]
        # input_vars is something like [<BEGIN> a b c]
        # output_vars is something like [a b c <END>]
        # if one sequence in the batch is longer than the others, the rest are padded to match
        choice = np.random.choice(len(self.context_files), p = self.context_props)
        n_contexts, fin = self.context_files[choice]
        predict_subtokens = struct.unpack('<8I', fin.read(32))
        pre_contexts, post_contexts = [], []
        for context_n in range(n_contexts):
            pre_contexts.append(struct.unpack('<64I', fin.read(256)))
            post_contexts.append(struct.unpack('<64I', fin.read(256)))
        return predict_subtokens, pre_contexts, post_contexts
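For context, a minimal sketch of how the vocabulary and loader might be exercised together; it assumes `data/counters.pkl` and the `data/train_tmp/*.bin` files written by `create_binary.py` already exist:

```python
from utils import ComboVocab, ContextLoader

vocab = ComboVocab()                                   # reads data/counters.pkl
loader = ContextLoader(folder_name='data/train_tmp/')  # scans the *.bin context files

# one draw: the 8 target subtoken ids plus the pre/post contexts from a sampled file
targets, pre_contexts, post_contexts = loader.get_batch()

print(vocab.to_tokens(list(targets)))           # the 8 target subtoken ids, as subtokens
print(vocab.to_tokens(list(pre_contexts[0])))   # 64 subtokens before the usage
print(vocab.to_tokens(list(post_contexts[0])))  # 64 subtokens after it (stored reversed)
```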
