Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ textattack.constraints.semantics.sentence\_encoders package
.. toctree::
:maxdepth: 6

textattack.constraints.semantics.sentence_encoders.bert
textattack.constraints.semantics.sentence_encoders.sentence_bert
textattack.constraints.semantics.sentence_encoders.infer_sent
textattack.constraints.semantics.sentence_encoders.universal_sentence_encoder

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
textattack.constraints.semantics.sentence\_encoders.bert package
================================================================

.. automodule:: textattack.constraints.semantics.sentence_encoders.bert
.. automodule:: textattack.constraints.semantics.sentence_encoders.sentence_bert
:members:
:undoc-members:
:show-inheritance:




.. automodule:: textattack.constraints.semantics.sentence_encoders.bert.bert
.. automodule:: textattack.constraints.semantics.sentence_encoders.sentence_bert.sbert
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
# Semantics constraints
#
"embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
"bert": "textattack.constraints.semantics.sentence_encoders.BERT",
"sbert": "textattack.constraints.semantics.sentence_encoders.SBERT",
"infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
"thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
"use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
Expand Down
4 changes: 2 additions & 2 deletions textattack/attack_recipes/a2t_yoo_2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.constraints.semantics.sentence_encoders import BERT
from textattack.constraints.semantics.sentence_encoders import SBERT
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
Expand Down Expand Up @@ -49,7 +49,7 @@ def build(model_wrapper, mlm=False):
constraints.append(input_column_modification)
constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
sent_encoder = BERT(
sent_encoder = SBERT(
model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
)
constraints.append(sent_encoder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .sentence_encoder import SentenceEncoder

from .bert import BERT
from .sentence_bert import SBERT
from .infer_sent import InferSent
from .thought_vector import ThoughtVector
from .universal_sentence_encoder import (
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
sBERT
^^^^^^^
"""

from .sbert import SBERT
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)


class BERT(SentenceEncoder):
class SBERT(SentenceEncoder):
"""Constraint using similarity between sentence encodings of x and x_adv
where the text embeddings are created using BERT, trained on NLI data, and
fine- tuned on the STS benchmark dataset.
Expand Down
4 changes: 2 additions & 2 deletions textattack/metrics/quality_metrics/sentence_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import BERT
from textattack.constraints.semantics.sentence_encoders import SBERT
from textattack.metrics import Metric


class SBERTMetric(Metric):
def __init__(self, **kwargs):
self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.use_obj = SBERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}
Expand Down