Merge pull request #63 from chiragjn/opt_gpt2
Use `past` argument for GPT2 to speed up decoding
makcedward authored Nov 19, 2019
2 parents 9863867 + 0762f78 commit 264371c
Showing 3 changed files with 33 additions and 17 deletions.
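
The speed-up in this commit comes from GPT-2's key/value cache: once the model has handed back `past`, each later decoding step only needs to feed the newest token instead of re-encoding the whole prefix. A minimal sketch of that pattern, assuming a 2019-era `transformers` release in which the cache argument is still called `past` (later releases renamed it `past_key_values`) and using greedy decoding rather than the sampling the augmenter does:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()

    input_ids = torch.tensor([tokenizer.encode('The quick brown fox')])
    generated = input_ids
    past = None

    with torch.no_grad():
        for _ in range(20):
            # With the cache, only the newest token is fed each step; the model
            # reuses the attention keys/values stored in `past`.
            logits, past = model(input_ids=input_ids, past=past)[:2]
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_id], dim=-1)
            input_ids = next_id  # from now on, feed only the last token

    print(tokenizer.decode(generated[0].tolist()))

The diff below relies on the same contract: `outputs[0]` holds the logits and `outputs[1]` the cache to feed back on the next call.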
23 changes: 12 additions & 11 deletions nlpaug/augmenter/sentence/context_word_embs_sentence.py
@@ -12,7 +12,7 @@


def init_context_word_embs_sentence_model(model_path, device, force_reload=False, temperature=1.0, top_k=None,
- top_p=None):
+ top_p=None, return_past=True):
global CONTEXT_WORD_EMBS_SENTENCE_MODELS

model_name = os.path.basename(model_path)
@@ -23,16 +23,12 @@ def init_context_word_embs_sentence_model(model_path, device, force_reload=False
return CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name]

if 'xlnet' in model_path:
- model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p)
+ model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, return_past=return_past)
elif 'gpt2' in model_path:
- model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p)
+ model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, return_past=return_past)
else:
raise ValueError('Model name value is unexpected. Only support XLNet and GPT2 model.')

CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name] = model
return model


class ContextualWordEmbsForSentenceAug(SentenceAugmenter):
# https://arxiv.org/pdf/1707.07328.pdf
"""
@@ -52,8 +48,8 @@ class ContextualWordEmbsForSentenceAug(SentenceAugmenter):
Default value is False and suggesting to keep it as False if performance is the consideration.
:param str name: Name of this augmenter
- >>> import nlpaug.augmenter.word as naw
- >>> aug = naw.ContextualWordEmbsForSentenceAug()
+ >>> import nlpaug.augmenter.sentence as nas
+ >>> aug = nas.ContextualWordEmbsForSentenceAug()
"""

def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=100, top_p=None,
@@ -86,7 +82,8 @@ def insert(self, data):
if data is None or data == '' or data.strip() == '':
return data

- max_try = 100
+ max_try = 30 # On average 30 should be enough to complete a sentence
+ past = None
augmented_text = ''

for _ in range(max_try):
@@ -95,7 +92,11 @@
if self.model_type in ['xlnet']:
text += ' ' + self.model.MASK_TOKEN

- results = self.model.predict(text, n=1)
+ results = self.model.predict(text, n=1, past=past)
+
+ if self.model.return_past:
+ results, past = results
+
new_word, proba = results[0]

if new_word in self.SENTENCE_SEPARATOR:
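
As a quick check of the corrected docstring example above, a usage sketch (illustrative only; the augmenter samples from the model, so output varies between runs, and `model_path='gpt2'` is chosen here because this PR targets the GPT-2 path):

    import nlpaug.augmenter.sentence as nas

    aug = nas.ContextualWordEmbsForSentenceAug(model_path='gpt2')
    print(aug.augment('The quick brown fox jumps over the lazy dog.'))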
14 changes: 11 additions & 3 deletions nlpaug/model/lang_models/gpt2.py
@@ -12,7 +12,7 @@ class Gpt2(LanguageModels):
# https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
SUBWORD_PREFIX = 'Ġ'

- def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, device=None):
+ def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, device=None, return_past=False):
super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p)
self.model_path = model_path

@@ -22,17 +22,21 @@ def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, d
self.model.to(self.device)
self.model.eval()

+ self.return_past = return_past
+
def id2token(self, _id):
return self.tokenizer.decode(_id, clean_up_tokenization_spaces=True).strip()

- def predict(self, text, target_word=None, n=1):
+ def predict(self, text, target_word=None, n=1, past=None):
# Convert feature
input_idxes = self.tokenizer.encode(text)
+ if past is not None:
+ input_idxes = input_idxes[-1:]
input_idxes = torch.tensor(input_idxes, device=self.device).unsqueeze(0).repeat(1, 1)

# Prediction
with torch.no_grad():
- outputs = self.model(input_idxes)
+ outputs = self.model(input_ids=input_idxes, past=past)
target_token_logits = outputs[0][0][-1] # GPT2 only predict last token

# Selection
@@ -41,4 +45,8 @@ def predict(self, text, target_word=None, n=1):
target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)

results = self.pick(target_token_logits, target_word=target_word, n=n)
+ if self.return_past:
+ past = outputs[1]
+ results = (results, past,)
+
return results
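
From the caller's side, the patched `Gpt2.predict()` now returns a `(results, past)` tuple when `return_past=True`, and once a cache is supplied it encodes only the last token of the input. A hedged sketch of driving it directly (assuming `nml` resolves to `nlpaug.model.lang_models`, as it does in the augmenter module above):

    import nlpaug.model.lang_models as nml

    gpt2 = nml.Gpt2(model_path='gpt2', temperature=1.0, top_k=100, return_past=True)

    text = 'The quick brown'
    past = None
    for _ in range(5):
        # First call: the full prefix is encoded and a cache comes back.
        # Later calls: only the last token of `text` is encoded, thanks to the cache.
        results, past = gpt2.predict(text, n=1, past=past)
        new_word, proba = results[0]
        text += ' ' + new_word
    print(text)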
13 changes: 10 additions & 3 deletions nlpaug/model/lang_models/xlnet.py
@@ -27,7 +27,7 @@ class XlNet(LanguageModels):
NEW_PARAGRAPH_TOKEN = '<eop>'

def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=None, top_p=None, padding_text=None,
- device=None):
+ device=None, return_past=False):
super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p)
self.model_path = model_path

@@ -39,13 +39,16 @@ def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=None, t
self.model.to(self.device)
self.model.eval()

+ self.return_past = return_past
+
def id2token(self, _id):
return self.tokenizer.decode(_id, clean_up_tokenization_spaces=True).strip()

def clean(self, text):
return text.replace(self.NEW_PARAGRAPH_TOKEN, '').strip()

- def predict(self, text, target_word=None, n=1):
+ def predict(self, text, target_word=None, n=1, past=None):
+ # xlnet does not support `past`, instead there is `mems` which works differently
# Convert feature
input_idxes = self.tokenizer.encode(text)
concatenated_idxes = self.padding_text_idxes + input_idxes
@@ -72,4 +75,8 @@ def predict(self, text, target_word=None, n=1):
target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)

results = self.pick(target_token_logits, target_word=target_word, n=n)
- return results
+
+ if self.return_past:
+ results = (results, past,)  # Only for API compatibility; past is not used for xlnet
+
+ return results
