Add support for prompt augmentation #766

Merged
merged 6 commits into from
Mar 11, 2024
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,27 @@ You can also create your own augmenter from scratch by importing transformations
['What I cannot creae, I do not understand.', 'What I cannot creat, I do not understand.', 'What I cannot create, I do not nderstand.', 'What I cannot create, I do nt understand.', 'Wht I cannot create, I do not understand.']
```

#### Prompt Augmentation
In addition to augmenting regular text, you can augment prompts and then generate responses to
the augmented prompts using a large language model (LLM). The augmentation is performed using the same
`Augmenter` as above. To generate responses, you can use your own LLM, a HuggingFace LLM, or an OpenAI LLM.
Here's an example using a pretrained HuggingFace LLM:

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> from textattack.llms import HuggingFaceLLMWrapper
>>> from textattack.prompt_augmentation import PromptAugmentationPipeline
>>> augmenter = EmbeddingAugmenter(transformations_per_example=3)
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
>>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
>>> model_wrapper = HuggingFaceLLMWrapper(model, tokenizer)
>>> pipeline = PromptAugmentationPipeline(augmenter, model_wrapper)
>>> pipeline("Classify the following piece of text as `positive` or `negative`: This movie is great!")
[('Classify the following piece of text as `positive` or `negative`: This film is great!', ['positive']), ('Classify the following piece of text as `positive` or `negative`: This movie is fabulous!', ['positive']), ('Classify the following piece of text as `positive` or `negative`: This movie is wonderful!', ['positive'])]
```
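
The output above is a list of `(augmented_prompt, responses)` pairs. As a rough illustration of that shape only, here is a minimal sketch of the augment-then-generate flow. It is not TextAttack's implementation: `augment_then_generate`, `toy_augment`, and `toy_generate` are hypothetical stand-ins so that the snippet runs without downloading a model.

```python
from typing import Callable, List, Tuple


def augment_then_generate(
    prompt: str,
    augment: Callable[[str], List[str]],
    generate: Callable[[str], List[str]],
) -> List[Tuple[str, List[str]]]:
    """Pair every augmented prompt with the responses generated for it."""
    return [(augmented, generate(augmented)) for augmented in augment(prompt)]


# Hypothetical augmenter: swaps one word per variant.
def toy_augment(prompt: str) -> List[str]:
    return [prompt.replace("movie", "film"), prompt.replace("great", "wonderful")]


# Hypothetical "LLM": a keyword lookup that returns a single response.
def toy_generate(prompt: str) -> List[str]:
    is_positive = "great" in prompt or "wonderful" in prompt
    return ["positive" if is_positive else "negative"]


results = augment_then_generate("This movie is great!", toy_augment, toy_generate)
for augmented_prompt, responses in results:
    print(augmented_prompt, responses)
```

Any callable that maps a prompt to a list of strings can stand in for either step in this sketch, which mirrors the README's point that the response generator can be your own LLM.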


### Training Models: `textattack train`

Our model training code is available via `textattack train` to help you train LSTMs,
Expand Down
12 changes: 12 additions & 0 deletions docs/apidoc/textattack.constraints.pre_transformation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,15 @@ textattack.constraints.pre\_transformation package
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.constraints.pre_transformation.unmodifiable_indices
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.constraints.pre_transformation.unmodifiable_phrases
   :members:
   :undoc-members:
   :show-inheritance:
19 changes: 19 additions & 0 deletions docs/apidoc/textattack.llms.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
textattack.llms package
=========================

.. automodule:: textattack.llms
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.llms.huggingface_llm_wrapper
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.llms.chat_gpt_wrapper
   :members:
   :undoc-members:
   :show-inheritance:
13 changes: 13 additions & 0 deletions docs/apidoc/textattack.prompt_augmentation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
textattack.prompt_augmentation package
=======================================

.. automodule:: textattack.prompt_augmentation
   :members:
   :undoc-members:
   :show-inheritance:


.. automodule:: textattack.prompt_augmentation.prompt_augmentation_pipeline
   :members:
   :undoc-members:
   :show-inheritance:
2 changes: 2 additions & 0 deletions docs/apidoc/textattack.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ textattack package
textattack.datasets
textattack.goal_function_results
textattack.goal_functions
textattack.llms
textattack.loggers
textattack.metrics
textattack.models
textattack.prompt_augmentation
textattack.search_methods
textattack.shared
textattack.transformations
Expand Down
1 change: 1 addition & 0 deletions tests/test_command_line/update_test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This is useful for large changes, but be wary: the outputs still may
need to be manually edited to account for variance between runs.
"""

from helpers import run_command_and_get_result
from test_attack import attack_test_params
from test_augment import augment_test_params
Expand Down
31 changes: 31 additions & 0 deletions tests/test_constraints/test_pretransformation_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,34 @@ def test_stopword_modification(
            set(range(len(entailment_attacked_text.words)))
            - {1, 2, 3, 8, 9, 11, 16, 17, 20, 22, 25, 31, 34, 39, 40, 41, 43, 44}
        )

    def test_unmodifiable_indices(
        self, sentence_attacked_text, entailment_attacked_text
    ):
        constraint = textattack.constraints.pre_transformation.UnmodifiableIndices(
            [4, 5]
        )
        assert constraint._get_modifiable_indices(sentence_attacked_text) == (
            set(range(len(sentence_attacked_text.words))) - {4, 5}
        )
        sentence_attacked_text = sentence_attacked_text.delete_word_at_index(2)
        assert constraint._get_modifiable_indices(sentence_attacked_text) == (
            set(range(len(sentence_attacked_text.words))) - {3, 4}
        )
        assert constraint._get_modifiable_indices(entailment_attacked_text) == (
            set(range(len(entailment_attacked_text.words))) - {4, 5}
        )
        entailment_attacked_text = (
            entailment_attacked_text.insert_text_after_word_index(0, "two words")
        )
        assert constraint._get_modifiable_indices(entailment_attacked_text) == (
            set(range(len(entailment_attacked_text.words))) - {6, 7}
        )

    def test_unmodifiable_phrases(self, sentence_attacked_text):
        constraint = textattack.constraints.pre_transformation.UnmodifablePhrases(
            ["South Korea's", "oil", "monday"]
        )
        assert constraint._get_modifiable_indices(sentence_attacked_text) == (
            set(range(len(sentence_attacked_text.words))) - {0, 1, 9, 22}
        )
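
The `test_unmodifiable_indices` assertions above expect the protected positions to follow their words as the text is edited: `{4, 5}` becomes `{3, 4}` after a deletion at index 2, and `{6, 7}` after two words are inserted after index 0. The sketch below shows one way such tracking could work. It is an illustration under assumptions, not the library's code: `TrackedWords` and `modifiable_indices` are hypothetical names.

```python
from typing import List, Optional, Set


class TrackedWords:
    """Hypothetical word list that remembers where each original word now sits."""

    def __init__(self, words: List[str]) -> None:
        self.words = list(words)
        # index_map[i] is the current position of original word i (None if deleted).
        self.index_map: List[Optional[int]] = list(range(len(words)))

    def delete(self, index: int) -> "TrackedWords":
        result = TrackedWords(self.words[:index] + self.words[index + 1 :])
        result.index_map = [
            None if pos is None or pos == index else (pos - 1 if pos > index else pos)
            for pos in self.index_map
        ]
        return result

    def insert_after(self, index: int, text: str) -> "TrackedWords":
        new_words = text.split()
        result = TrackedWords(
            self.words[: index + 1] + new_words + self.words[index + 1 :]
        )
        result.index_map = [
            pos if pos is None or pos <= index else pos + len(new_words)
            for pos in self.index_map
        ]
        return result


def modifiable_indices(text: TrackedWords, protected: List[int]) -> Set[int]:
    """Current positions that are NOT protected (protected uses original indices)."""
    blocked = {
        text.index_map[i]
        for i in protected
        if i < len(text.index_map) and text.index_map[i] is not None
    }
    return set(range(len(text.words))) - blocked


original = TrackedWords("a b c d e f g h".split())
shorter = original.delete(2)                    # protected words move to 3 and 4
longer = original.insert_after(0, "two words")  # protected words move to 6 and 7
print(modifiable_indices(shorter, [4, 5]))
print(modifiable_indices(longer, [4, 5]))
```

Keeping the protected list in original coordinates means the constraint object itself never has to change between edits; only the text carries the mapping.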
25 changes: 25 additions & 0 deletions tests/test_prompt_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
def test_prompt_augmentation_pipeline():
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from textattack.augmentation.recipes import CheckListAugmenter
    from textattack.constraints.pre_transformation import UnmodifiableIndices
    from textattack.llms import HuggingFaceLLMWrapper
    from textattack.prompt_augmentation import PromptAugmentationPipeline

    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model_wrapper = HuggingFaceLLMWrapper(model, tokenizer)

    augmenter = CheckListAugmenter()

    pipeline = PromptAugmentationPipeline(augmenter, model_wrapper)

    prompt = "As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: Poor Ben Bratt couldn't find stardom if MapQuest emailed him point-to-point driving directions."
    prompt_constraints = [UnmodifiableIndices([2, 3, 10, 12, 14])]

    output = pipeline(prompt, prompt_constraints)

    assert len(output) == 1
    assert len(output[0]) == 2
    assert "could not" in output[0][0]
    assert "negative" in output[0][1]
1 change: 1 addition & 0 deletions textattack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
"""

from .attack_args import AttackArgs, CommandLineAttackArgs
from .augment_args import AugmenterArgs
from .dataset_args import DatasetArgs
Expand Down
6 changes: 3 additions & 3 deletions textattack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ def filter_transformations(
uncached_texts.append(transformed_text)
else:
# promote transformed_text to the top of the LRU cache
self.constraints_cache[
(current_text, transformed_text)
] = self.constraints_cache[(current_text, transformed_text)]
self.constraints_cache[(current_text, transformed_text)] = (
self.constraints_cache[(current_text, transformed_text)]
)
if self.constraints_cache[(current_text, transformed_text)]:
filtered_texts.append(transformed_text)
filtered_texts += self._filter_transformations_uncached(
Expand Down
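
The reformatted statement in this hunk assigns a cache entry to itself so that, per the comment in the source, the entry is promoted to the top of the LRU cache. The toy cache below illustrates that idiom. `TinyLRU` is a hypothetical class built on `collections.OrderedDict`; it assumes that writes refresh recency and makes no claim about how the cache used in `attack.py` behaves internally.

```python
from collections import OrderedDict
from typing import Hashable


class TinyLRU:
    """Toy LRU mapping in which every write marks the key as most recently used."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: "OrderedDict[Hashable, bool]" = OrderedDict()

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data

    def __getitem__(self, key: Hashable) -> bool:
        # Reads do not change recency in this toy version.
        return self._data[key]

    def __setitem__(self, key: Hashable, value: bool) -> None:
        self._data[key] = value
        self._data.move_to_end(key)  # mark as most recently used
        while len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict the least recently used key


cache = TinyLRU(capacity=2)
cache["a"] = True
cache["b"] = False
cache["a"] = cache["a"]  # self-assignment: promote "a" without changing its value
cache["c"] = True        # over capacity, so the least recently used key is evicted
print("a" in cache, "b" in cache, "c" in cache)
```

Without the self-assignment, `"a"` would be the oldest entry and would be the one evicted when `"c"` is added.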
1 change: 1 addition & 0 deletions textattack/attack_recipes/bae_garg_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
============================================

"""

from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
RepeatModification,
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/bert_attack_li_2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Consider using smaller values for "max_candidates".

"""

from textattack import Attack
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/checklist_ribeiro_2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
(Beyond Accuracy: Behavioral Testing of NLP models with CheckList)

"""

from textattack import Attack
from textattack.constraints.pre_transformation import RepeatModification
from textattack.goal_functions import UntargetedClassification
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/hotflip_ebrahimi_2017.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
(HotFlip: White-Box Adversarial Examples for Text Classification)

"""

from textattack import Attack
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.overlap import MaxWordsPerturbed
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/iga_wang_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
(Natural Language Adversarial Attacks and Defenses in Word Level)

"""

from textattack import Attack
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import StopwordModification
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/input_reduction_feng_2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
(Pathologies of Neural Models Make Interpretations Difficult)

"""

from textattack import Attack
from textattack.constraints.pre_transformation import (
RepeatModification,
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/kuleshov_2017.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
(Adversarial Examples for Natural Language Classification Problems)

"""

from textattack import Attack
from textattack.constraints.grammaticality.language_models import GPT2
from textattack.constraints.overlap import MaxWordsPerturbed
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/morpheus_tan_2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


"""

from textattack import Attack
from textattack.constraints.pre_transformation import (
RepeatModification,
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/pruthi_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
=================================================================

"""

from textattack import Attack
from textattack.constraints.overlap import MaxWordsPerturbed
from textattack.constraints.pre_transformation import (
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/pso_zang_2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
(Word-level Textual Adversarial Attacking as Combinatorial Optimization)

"""

from textattack import Attack
from textattack.constraints.pre_transformation import (
InputColumnModification,
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/pwws_ren_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
(Generating Natural Language Adversarial Examples through Probability Weighted Word Saliency)

"""

from textattack import Attack
from textattack.constraints.pre_transformation import (
RepeatModification,
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/seq2sick_cheng_2018_blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
================================================
(Seq2Sick: Evaluating the Robustness of Sequence-to-Sequence Models with Adversarial Examples)
"""

from textattack import Attack
from textattack.constraints.overlap import LevenshteinEditDistance
from textattack.constraints.pre_transformation import (
Expand Down
1 change: 0 additions & 1 deletion textattack/attack_results/successful_attack_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

"""


from .attack_result import AttackResult


Expand Down
1 change: 0 additions & 1 deletion textattack/augment_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
===================
"""


from dataclasses import dataclass

AUGMENTATION_RECIPE_NAMES = {
Expand Down
1 change: 0 additions & 1 deletion textattack/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Transformations and constraints can be used outside of an attack for simple NLP data augmentation with the ``Augmenter`` class that returns all possible transformations for a given string.
"""


from .augmenter import Augmenter
from .recipes import (
WordNetAugmenter,
Expand Down
1 change: 1 addition & 0 deletions textattack/augmentation/augmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Augmenter Class
===================
"""

import random

import tqdm
Expand Down
1 change: 1 addition & 0 deletions textattack/augmentation/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Transformations and constraints can be used for simple NLP data augmentations. Here is a list of recipes for NLP data augmentations

"""

import random

from textattack.constraints.pre_transformation import (
Expand Down
1 change: 0 additions & 1 deletion textattack/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

"""


from abc import ABC, abstractmethod
from .textattack_command import TextAttackCommand
from . import textattack_cli
1 change: 0 additions & 1 deletion textattack/commands/eval_model_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

"""


from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass

Expand Down
1 change: 0 additions & 1 deletion textattack/commands/textattack_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

"""


# !/usr/bin/env python
import argparse

Expand Down
1 change: 0 additions & 1 deletion textattack/commands/train_model_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

"""


from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

from textattack import CommandLineTrainingArgs, Trainer
Expand Down
1 change: 1 addition & 0 deletions textattack/constraints/grammaticality/cola.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
--------------------------

"""

import lru
import nltk
from transformers import AutoModelForSequenceClassification, AutoTokenizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

"""


from .language_model_constraint import LanguageModelConstraint

from .google_language_model import Google1BillionWordsLanguageModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

"""


from .google_language_model import (
GoogleLanguageModel as Google1BillionWordsLanguageModel,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
All rights reserved.
"""


import os

import lru
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
--------------------------------------

"""

from collections import defaultdict

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Author: Moustafa Alzantot (malzantot@ucla.edu)
All rights reserved.
"""

import sys

from textattack.shared.utils import LazyLoader
Expand Down