Commit 9826d89 (1 parent: 32beaa6)
Showing 7 changed files with 182 additions and 0 deletions.
@@ -0,0 +1,7 @@
nlpaug.augmenter.word.back_translation
======================================

.. automodule:: nlpaug.augmenter.word.back_translation
   :members:
   :inherited-members:
   :show-inheritance:
@@ -5,6 +5,7 @@ Word Augmenter
    :maxdepth: 6

    ./antonym
+   ./back_translation
    ./context_word_embs
    ./random
    ./spelling
@@ -0,0 +1,82 @@
"""
Augmenter that applies word-level operations to textual input based on back translation.
"""

from nlpaug.augmenter.word import WordAugmenter
import nlpaug.model.lang_models as nml

BACK_TRANSLATION_MODELS = {}


def init_back_translation_model(from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
                                tokenizer_name, bpe_name, device, force_reload=False):
    global BACK_TRANSLATION_MODELS

    model_name = '_'.join([from_model_name, to_model_name])
    if model_name in BACK_TRANSLATION_MODELS and not force_reload:
        return BACK_TRANSLATION_MODELS[model_name]
    model = nml.Fairseq(from_model_name=from_model_name, from_model_checkpt=from_model_checkpt,
                        to_model_name=to_model_name, to_model_checkpt=to_model_checkpt,
                        tokenizer_name=tokenizer_name, bpe_name=bpe_name, device=device)

    BACK_TRANSLATION_MODELS[model_name] = model
    return model


class BackTranslationAug(WordAugmenter):
    # https://arxiv.org/pdf/1511.06709.pdf
    """
    Augmenter that leverages two translation models for augmentation. For example, if the source is English,
    this augmenter translates it to German and then back to English. For details, you may visit
    https://towardsdatascience.com/data-augmentation-in-nlp-2801a34dfc28

    :param str from_model_name: Model for translating your text. Verified values are 'transformer.wmt19.en-de',
        'transformer.wmt19.de-en', 'transformer.wmt19.en-ru' and 'transformer.wmt19.ru-en'.
    :param str to_model_name: Model for translating back to the source language. Verified values are
        'transformer.wmt19.en-de', 'transformer.wmt19.de-en', 'transformer.wmt19.en-ru' and
        'transformer.wmt19.ru-en'.
    :param str tokenizer: Default value is 'moses'.
    :param str bpe: Default value is 'fastbpe'.
    :param str device: Use either 'cpu' or 'cuda'. Default value is None, which uses GPU if available.
    :param bool force_reload: Force reload the translation models into memory when initializing the class.
        Default value is False; keep it as False if performance is a consideration.
    :param str name: Name of this augmenter.

    >>> import nlpaug.augmenter.word as naw
    >>> aug = naw.BackTranslationAug()
    """

    def __init__(self, from_model_name, to_model_name, from_model_checkpt='model1.pt', to_model_checkpt='model1.pt',
                 tokenizer='moses', bpe='fastbpe', name='BackTranslationAug', device=None, force_reload=False,
                 verbose=0):
        super().__init__(
            # TODO: does not support include_detail
            action='substitute', name=name, aug_p=None, aug_min=None, aug_max=None, tokenizer=None,
            device=device, verbose=verbose, include_detail=False)

        self.model = self.get_model(
            from_model_name=from_model_name, from_model_checkpt=from_model_checkpt,
            to_model_name=to_model_name, to_model_checkpt=to_model_checkpt,
            tokenizer_name=tokenizer, bpe_name=bpe, device=device, force_reload=force_reload
        )
        self.device = self.model.device

    def substitute(self, data):
        augmented_text = self.model.predict(data)
        return augmented_text

    @classmethod
    def get_model(cls, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
                  tokenizer_name, bpe_name, device='cuda', force_reload=False):
        return init_back_translation_model(from_model_name, from_model_checkpt,
                                           to_model_name, to_model_checkpt, tokenizer_name, bpe_name,
                                           device, force_reload)

    @classmethod
    def clear_cache(cls):
        global BACK_TRANSLATION_MODELS
        BACK_TRANSLATION_MODELS = {}
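For reference, a minimal usage sketch of the new augmenter as this commit defines it, assuming the verified WMT19 model pair from the docstring; the first run downloads the fairseq checkpoints through torch.hub, so it needs network access plus the torch and fairseq dependencies:

import nlpaug.augmenter.word as naw

# Forward model translates English to German; backward model translates it back.
aug = naw.BackTranslationAug(
    from_model_name='transformer.wmt19.en-de',
    to_model_name='transformer.wmt19.de-en')

text = 'The quick brown fox jumps over the lazy dog'
augmented_text = aug.augment(text)  # typically a paraphrase of the input
print(augmented_text)

Each loaded pair is kept in the module-level BACK_TRANSLATION_MODELS dictionary, so constructing the same augmenter twice reuses the already-loaded models instead of hitting torch.hub again.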
@@ -0,0 +1,50 @@
try:
    import torch
except ImportError:
    # No installation required if not using this function
    pass

from nlpaug.model.lang_models import LanguageModels
from nlpaug.util.selection.filtering import *


class Fairseq(LanguageModels):
    def __init__(self, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
                 tokenizer_name='moses', bpe_name='fastbpe', device='cuda'):
        super().__init__(device, temperature=None, top_k=None, top_p=None)

        try:
            import torch
            import fairseq
            self.device = 'cuda' if device is None and torch.cuda.is_available() else device
        except ImportError:
            raise ImportError('Missing torch or fairseq library. Install torch by following '
                              'https://pytorch.org/get-started/locally/ and fairseq from '
                              'https://github.com/pytorch/fairseq')

        self.from_model_name = from_model_name
        self.from_model_checkpt = from_model_checkpt
        self.to_model_name = to_model_name
        self.to_model_checkpt = to_model_checkpt
        self.tokenizer_name = tokenizer_name
        self.bpe_name = bpe_name

        # TODO: enhance to support custom models. https://github.com/pytorch/fairseq/tree/master/examples/translation
        self.from_model = torch.hub.load(
            github='pytorch/fairseq', model=from_model_name,
            checkpoint_file=from_model_checkpt,
            tokenizer=tokenizer_name, bpe=bpe_name)
        self.to_model = torch.hub.load(
            github='pytorch/fairseq', model=to_model_name,
            checkpoint_file=to_model_checkpt,
            tokenizer=tokenizer_name, bpe=bpe_name)

        self.from_model.eval()
        self.to_model.eval()
        if self.device == 'cuda':
            self.from_model.cuda()
            self.to_model.cuda()

    def predict(self, text):
        translated_text = self.from_model.translate(text)
        back_translated_text = self.to_model.translate(translated_text)

        return back_translated_text
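The predict() round trip above is a thin wrapper over torch.hub's fairseq integration, where each hub entry bundles the moses tokenizer, fastbpe codes, and the transformer weights. A standalone sketch of the same round trip without nlpaug, assuming the en-ru/ru-en WMT19 checkpoints (again downloaded on first use):

import torch

# Load forward (en->ru) and backward (ru->en) translation models from torch.hub.
en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru',
                       checkpoint_file='model1.pt', tokenizer='moses', bpe='fastbpe')
ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en',
                       checkpoint_file='model1.pt', tokenizer='moses', bpe='fastbpe')
en2ru.eval()
ru2en.eval()

intermediate = en2ru.translate('The quick brown fox jumps over the lazy dog')
back_translated = ru2en.translate(intermediate)  # back to English, usually reworded
print(back_translated)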
@@ -0,0 +1,40 @@
import unittest
import os
from dotenv import load_dotenv

import nlpaug.augmenter.word as naw


class TestBackTranslationAug(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        env_config_path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', '..', '..', '.env'))
        load_dotenv(env_config_path)

        cls.text = 'The quick brown fox jumps over the lazy dog'

        cls.model_names = [{
            'from_model_name': 'transformer.wmt19.en-ru',
            'from_model_checkpt': 'model1.pt',
            'to_model_name': 'transformer.wmt19.ru-en',
            'to_model_checkpt': 'model1.pt'
        }, {
            'from_model_name': 'transformer.wmt19.en-de',
            'from_model_checkpt': 'model1.pt',
            'to_model_name': 'transformer.wmt19.de-en',
            'to_model_checkpt': 'model1.pt'
        }]

    def test_back_translation(self):
        for model_name in self.model_names:
            aug = naw.BackTranslationAug(
                from_model_name=model_name['from_model_name'],
                from_model_checkpt=model_name['from_model_checkpt'],
                to_model_name=model_name['to_model_name'],
                to_model_checkpt=model_name['to_model_checkpt'])
            augmented_text = aug.augment(self.text)
            aug.clear_cache()
            self.assertNotEqual(self.text, augmented_text)

        self.assertTrue(len(self.model_names) > 1)
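The aug.clear_cache() call between pairs matters because the module-level registry would otherwise keep both multi-gigabyte model pairs in memory for the whole run. A small sketch of the caching behavior, assuming the new module is importable as nlpaug.augmenter.word.back_translation (the file name is not visible in this diff, so that path is an assumption):

import nlpaug.augmenter.word as naw
# Assumed module path for the cache dictionary added in this commit.
from nlpaug.augmenter.word import back_translation as bt

aug1 = naw.BackTranslationAug(from_model_name='transformer.wmt19.en-de',
                              to_model_name='transformer.wmt19.de-en')
# The pair is cached under '<from_model_name>_<to_model_name>'.
assert 'transformer.wmt19.en-de_transformer.wmt19.de-en' in bt.BACK_TRANSLATION_MODELS

# A second instance with the same pair reuses the cached Fairseq wrapper.
aug2 = naw.BackTranslationAug(from_model_name='transformer.wmt19.en-de',
                              to_model_name='transformer.wmt19.de-en')
assert aug2.model is aug1.model

naw.BackTranslationAug.clear_cache()
assert bt.BACK_TRANSLATION_MODELS == {}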