Commit

[peft] If AutoModel is wrapped with PEFT for prompt learning, then extend the attention mask (#3000)
tomaarsen authored Oct 29, 2024
1 parent f286d9f commit 1912788
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions sentence_transformers/models/Transformer.py
@@ -11,6 +11,7 @@
 import torch
 from torch import nn
 from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers.utils import is_peft_available
 
 logger = logging.getLogger(__name__)

@@ -350,15 +351,31 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
         output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
         output_tokens = output_states[0]
 
-        features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
+        # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, then it may have added virtual tokens
+        # We need to extend the attention mask to include these virtual tokens, or the pooling will fail
+        if is_peft_available():
+            from peft import PeftModelForFeatureExtraction
+
+            if (
+                isinstance(self.auto_model, PeftModelForFeatureExtraction)
+                and self.auto_model.active_peft_config.is_prompt_learning
+            ):
+                batch_size = output_tokens.size(0)
+                attention_mask = features["attention_mask"]
+                prefix_attention_mask = torch.ones(
+                    batch_size, self.auto_model.active_peft_config.num_virtual_tokens, device=attention_mask.device
+                )
+                features["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        features["token_embeddings"] = output_tokens
 
         if self.auto_model.config.output_hidden_states and len(output_states) > 2:
             all_layer_idx = 2  # I.e. after last_hidden_states and pooler_output
             if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                 all_layer_idx = 1
 
             hidden_states = output_states[all_layer_idx]
-            features.update({"all_layer_embeddings": hidden_states})
+            features["all_layer_embeddings"] = hidden_states
 
         return features
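
Note on the new branch: PEFT prompt learning methods (e.g. prompt tuning) prepend num_virtual_tokens learned embeddings to every sequence, so output_tokens gains that many extra token embeddings while features["attention_mask"] still covers only the real tokens; without the concatenated prefix mask, pooling over token_embeddings would see mismatched lengths. The sketch below shows one way a model can end up in this state; the checkpoint name, num_virtual_tokens=8, and wrapping model[0].auto_model directly with peft.get_peft_model are illustrative assumptions rather than the documented sentence-transformers API, and end-to-end behaviour depends on the installed peft and transformers versions.

# Illustrative sketch (assumed setup): wrap the underlying AutoModel with PEFT prompt tuning
# so that Transformer.forward() hits the attention-mask extension added in this commit.
from peft import PromptTuningConfig, TaskType, get_peft_model
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # assumed checkpoint

# A prompt learning config: is_prompt_learning is True and 8 virtual tokens are prepended.
peft_config = PromptTuningConfig(task_type=TaskType.FEATURE_EXTRACTION, num_virtual_tokens=8)

# model[0] is the Transformer module; its auto_model becomes a PeftModelForFeatureExtraction.
model[0].auto_model = get_peft_model(model[0].auto_model, peft_config)

# Each sequence now produces num_virtual_tokens extra token embeddings; the code above
# extends features["attention_mask"] to match, so mean pooling keeps working.
embeddings = model.encode(["PEFT prompt tuning adds virtual tokens."])
print(embeddings.shape)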
