Support decoder text embedding model in the examples/transformers_text.py #333

Merged · 17 commits · Jan 11, 2024
Update
zechengz committed Jan 9, 2024
commit ff9d68932af5ac821acfa4115383f49f0d52fce8
45 changes: 18 additions & 27 deletions examples/transformers_text.py
@@ -95,8 +95,6 @@
"intfloat/e5-mistral-7b-instruct",
],
)
parser.add_argument("--pooling", type=str, default="mean",
choices=["mean", "cls", "last"])
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
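Note that this commit drops the --pooling flag: the pooling strategy is now implied by the chosen model rather than passed on the command line. A minimal sketch of that kind of dispatch (the choose_pooling helper below is a hypothetical illustration, not code from this diff):

def choose_pooling(model_name: str) -> str:
    # Hypothetical helper: a decoder-style embedding model such as
    # "intfloat/e5-mistral-7b-instruct" presumably uses last-token pooling,
    # while encoder models keep mean pooling over non-padding tokens.
    if model_name == "intfloat/e5-mistral-7b-instruct":
        return "last"
    return "mean"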

@@ -106,7 +104,7 @@


class TextToEmbedding:
def __init__(self, model: str, pooling: str, device: torch.device):
def __init__(self, model: str, device: torch.device):
self.model_name = model
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model)
@@ -118,15 +116,13 @@ def __init__(self, model: str, pooling: str, device: torch.device):
).to(device)
else:
self.model = AutoModel.from_pretrained(model).to(device)
self.pooling = pooling
self.pooling = "mean"

def __call__(self, sentences: List[str]) -> Tensor:
if self.model_name == "intfloat/e5-mistral-7b-instruct":
sentences = [
get_detailed_instruct(
"Retrieve relevant knowledge and embeddings.", sentence)
for sentence in sentences
]
sentences = [(f"Instruct: Retrieve relevant knowledge and "
f"embeddings.\nQuery: {sentence}")
for sentence in sentences]
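# Illustration (not part of the diff): a query such as "what is a panda?"
# becomes the single string
# "Instruct: Retrieve relevant knowledge and embeddings.\nQuery: what is a panda?"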
max_length = 4096
inputs = self.tokenizer(
sentences,
@@ -156,12 +152,15 @@ def __call__(self, sentences: List[str]) -> Tensor:
if isinstance(inputs[key], Tensor):
inputs[key] = inputs[key].to(self.device)
out = self.model(**inputs)

# [batch_size, max_length or batch_max_length]
# Value is either one or zero, where zero means that
# the token is padding and is not attended to.
mask = inputs["attention_mask"]

if self.pooling == "mean":
return (mean_pooling(out.last_hidden_state.detach(),
mask).squeeze(1).cpu())
elif self.pooling == "cls":
return out.last_hidden_state[:, 0, :].detach().cpu()
elif self.pooling == "last":
return last_pooling(out.last_hidden_state,
mask).detach().cpu().to(torch.float32)
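For readers unfamiliar with the two strategies, here is a minimal, self-contained sketch of mean pooling versus last-token pooling on toy tensors (assuming right padding; the last_pooling helper in this file also handles the left-padded case):

import torch

# Toy batch: 2 sequences, max length 4, hidden size 3.
hidden = torch.randn(2, 4, 3)
# Right-padded attention masks: 1 = real token, 0 = padding.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

# Mean pooling: average the hidden states of non-padding positions.
expanded = mask.unsqueeze(-1).expand(hidden.size()).float()
mean_emb = (hidden * expanded).sum(dim=1) / expanded.sum(dim=1).clamp(min=1e-9)

# Last-token pooling (right padding): take the hidden state of the last
# non-padding token in each sequence.
last_index = mask.sum(dim=1) - 1  # tensor([2, 1])
last_emb = hidden[torch.arange(hidden.size(0)), last_index]

print(mean_emb.shape, last_emb.shape)  # torch.Size([2, 3]) torch.Size([2, 3])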
@@ -179,12 +178,10 @@ class TextToEmbeddingFinetune(torch.nn.Module):
model (str): Model name to load by using :obj:`transformers`,
such as :obj:`distilbert-base-uncased` and
:obj:`sentence-transformers/all-distilroberta-v1`.
pooling (str): Pooling strategy to pool context embeddings into
sentence level embedding. (default: :obj:`'mean'`)
lora (bool): Whether using LoRA to finetune the text model.
(default: :obj:`False`)
"""
def __init__(self, model: str, pooling: str = "mean", lora: bool = False):
def __init__(self, model: str, lora: bool = False):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.model = AutoModel.from_pretrained(model)
@@ -208,7 +205,6 @@ def __init__(self, model: str, pooling: str = "mean", lora: bool = False):
target_modules=target_modules,
)
self.model = get_peft_model(self.model, peft_config)
self.pooling = pooling

def forward(self, feat: dict[str, MultiNestedTensor]) -> Tensor:
# Pad [batch_size, 1, *] into [batch_size, 1, batch_max_seq_len], then,
@@ -223,20 +219,15 @@ def forward(self, feat: dict[str, MultiNestedTensor]) -> Tensor:
out = self.model(input_ids=input_ids, attention_mask=mask)

# Return value has the shape [batch_size, 1, text_model_out_channels]
if self.pooling == "mean":
return mean_pooling(out.last_hidden_state, mask)
elif self.pooling == "cls":
return out.last_hidden_state[:, 0, :].unsqueeze(1)
else:
raise ValueError(f"{self.pooling} is not supported.")
return mean_pooling(out.last_hidden_state, mask)
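As a rough shape walk-through of this forward pass (plain tensors standing in for torch_frame's MultiNestedTensor; dimensions are illustrative only, not the actual API):

import torch

batch_size, batch_max_seq_len = 8, 16

# Each text column is padded to a dense [batch_size, 1, batch_max_seq_len]
# tensor for both the token ids and the attention mask.
input_ids = torch.randint(0, 30522, (batch_size, 1, batch_max_seq_len))
mask = torch.ones(batch_size, 1, batch_max_seq_len, dtype=torch.long)

# The extra column dimension is dropped before calling the Transformer:
input_ids, mask = input_ids.squeeze(dim=1), mask.squeeze(dim=1)
# -> [batch_size, batch_max_seq_len]

# out.last_hidden_state is then [batch_size, batch_max_seq_len, hidden_size],
# and mean pooling over the sequence dimension yields the
# [batch_size, 1, text_model_out_channels] output mentioned in the comment above.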

def tokenize(self, sentences: List[str]) -> TextTokenizationOutputs:
# Tokenize batches of sentences
return self.tokenizer(sentences, truncation=True, padding=True,
return_tensors='pt')


def mean_pooling(last_hidden_state: Tensor, attention_mask) -> Tensor:
def mean_pooling(last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
input_mask_expanded = (attention_mask.unsqueeze(-1).expand(
last_hidden_state.size()).float())
embedding = torch.sum(
@@ -248,8 +239,11 @@ def mean_pooling(last_hidden_state: Tensor, attention_mask) -> Tensor:
def last_pooling(last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
# Every sample's last position is a real token, i.e. the batch is
# left-padded, so take the final hidden state directly.
return last_hidden_state[:, -1]
else:
# Find the last token that attends to previous tokens.
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
return last_hidden_state[
@@ -270,8 +264,7 @@ def get_detailed_instruct(task_description: str, query: str) -> str:

# Prepare text columns
if not args.finetune:
text_encoder = TextToEmbedding(model=args.model, pooling=args.pooling,
device=device)
text_encoder = TextToEmbedding(model=args.model, device=device)
text_stype = torch_frame.text_embedded
kwargs = {
"text_stype":
@@ -280,9 +273,7 @@ def get_detailed_instruct(task_description: str, query: str) -> str:
TextEmbedderConfig(text_embedder=text_encoder, batch_size=10),
}
else:
text_encoder = TextToEmbeddingFinetune(model=args.model,
pooling=args.pooling,
lora=args.lora)
text_encoder = TextToEmbeddingFinetune(model=args.model, lora=args.lora)
text_stype = torch_frame.text_tokenized
kwargs = {
"text_stype":