diff --git a/nlpaug/augmenter/sentence/context_word_embs_sentence.py b/nlpaug/augmenter/sentence/context_word_embs_sentence.py index baeb44f..6c92e46 100755 --- a/nlpaug/augmenter/sentence/context_word_embs_sentence.py +++ b/nlpaug/augmenter/sentence/context_word_embs_sentence.py @@ -12,7 +12,7 @@ def init_context_word_embs_sentence_model(model_path, device, force_reload=False, temperature=1.0, top_k=None, - top_p=None): + top_p=None, return_past=True): global CONTEXT_WORD_EMBS_SENTENCE_MODELS model_name = os.path.basename(model_path) @@ -23,16 +23,12 @@ def init_context_word_embs_sentence_model(model_path, device, force_reload=False return CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name] if 'xlnet' in model_path: - model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p) + model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, return_past=return_past) elif 'gpt2' in model_path: - model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p) + model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, return_past=return_past) else: raise ValueError('Model name value is unexpected. Only support XLNet and GPT2 model.') - CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name] = model - return model - - class ContextualWordEmbsForSentenceAug(SentenceAugmenter): # https://arxiv.org/pdf/1707.07328.pdf """ @@ -52,8 +48,8 @@ class ContextualWordEmbsForSentenceAug(SentenceAugmenter): Default value is False and suggesting to keep it as False if performance is the consideration. :param str name: Name of this augmenter - >>> import nlpaug.augmenter.word as naw - >>> aug = naw.ContextualWordEmbsForSentenceAug() + >>> import nlpaug.augmenter.sentence as nas + >>> aug = nas.ContextualWordEmbsForSentenceAug() """ def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=100, top_p=None, @@ -86,7 +82,8 @@ def insert(self, data): if data is None or data == '' or data.strip() == '': return data - max_try = 100 + max_try = 30 # On average 30 should be enough to complete a sentence + past = None augmented_text = '' for _ in range(max_try): @@ -95,7 +92,11 @@ def insert(self, data): if self.model_type in ['xlnet']: text += ' ' + self.model.MASK_TOKEN - results = self.model.predict(text, n=1) + results = self.model.predict(text, n=1, past=past) + + if self.model.return_past: + results, past = results + new_word, proba = results[0] if new_word in self.SENTENCE_SEPARATOR: diff --git a/nlpaug/model/lang_models/gpt2.py b/nlpaug/model/lang_models/gpt2.py index f016525..b5cf371 100755 --- a/nlpaug/model/lang_models/gpt2.py +++ b/nlpaug/model/lang_models/gpt2.py @@ -12,7 +12,7 @@ class Gpt2(LanguageModels): # https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf SUBWORD_PREFIX = 'Ġ' - def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, device=None): + def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, device=None, return_past=False): super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p) self.model_path = model_path @@ -22,17 +22,21 @@ def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, d self.model.to(self.device) self.model.eval() + self.return_past = return_past + def id2token(self, _id): return self.tokenizer.decode(_id, clean_up_tokenization_spaces=True).strip() - def predict(self, text, target_word=None, n=1): + def predict(self, text, target_word=None, n=1, past=None): # Convert feature input_idxes = self.tokenizer.encode(text) + if past is not None: + input_idxes = input_idxes[-1:] input_idxes = torch.tensor(input_idxes, device=self.device).unsqueeze(0).repeat(1, 1) # Prediction with torch.no_grad(): - outputs = self.model(input_idxes) + outputs = self.model(input_ids=input_idxes, past=past) target_token_logits = outputs[0][0][-1] # GPT2 only predict last token # Selection @@ -41,4 +45,8 @@ def predict(self, text, target_word=None, n=1): target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed) results = self.pick(target_token_logits, target_word=target_word, n=n) + if self.return_past: + past = outputs[1] + results = (results, past,) + return results diff --git a/nlpaug/model/lang_models/xlnet.py b/nlpaug/model/lang_models/xlnet.py index 86fcd93..d589e39 100755 --- a/nlpaug/model/lang_models/xlnet.py +++ b/nlpaug/model/lang_models/xlnet.py @@ -27,7 +27,7 @@ class XlNet(LanguageModels): NEW_PARAGRAPH_TOKEN = '' def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=None, top_p=None, padding_text=None, - device=None): + device=None, return_past=False): super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p) self.model_path = model_path @@ -39,13 +39,16 @@ def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=None, t self.model.to(self.device) self.model.eval() + self.return_past = return_past + def id2token(self, _id): return self.tokenizer.decode(_id, clean_up_tokenization_spaces=True).strip() def clean(self, text): return text.replace(self.NEW_PARAGRAPH_TOKEN, '').strip() - def predict(self, text, target_word=None, n=1): + def predict(self, text, target_word=None, n=1, past=None): + # xlnet does not support `past`, instead there is `mems` which works differently # Convert feature input_idxes = self.tokenizer.encode(text) concatenated_idxes = self.padding_text_idxes + input_idxes @@ -72,4 +75,8 @@ def predict(self, text, target_word=None, n=1): target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed) results = self.pick(target_token_logits, target_word=target_word, n=n) - return results \ No newline at end of file + + if self.return_past: + results = (results, past,) # Only, for API compatibility, past is not used for xlnet + + return results