
Commit

Merge pull request #92 from ashokrajab/onnx_conversion_fix
Modified masking before pooling - fixes issue in ONNX conversion
hongjin-su authored Apr 12, 2024
2 parents e749023 + 856211d commit d92edb7
Showing 7 changed files with 414 additions and 330 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1 +1,5 @@
 .idea/
+/cache
+/evaluation/MTEB/mteb.egg-info
+/**/__pycache__
+/InstructorEmbedding.egg-info
651 changes: 376 additions & 275 deletions InstructorEmbedding/instructor.py

Large diffs are not rendered by default.
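Editor's note: the unrendered instructor.py rewrite is where the ONNX fix itself lives. Judging only by the tensor rename visible in the train.py diff below (context_masks, a per-example token count, becomes instruction_mask, a per-token 0/1 tensor), the pooling plausibly switches from index arithmetic over a count tensor to element-wise masking, which traces cleanly through torch.onnx.export, unlike data-dependent slicing. A minimal sketch of that pooling pattern, with assumed tensor names, not the repo's exact code:

import torch

def masked_mean_pooling(token_embeddings: torch.Tensor,
                        attention_mask: torch.Tensor,
                        instruction_mask: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (B, T, D); both masks: (B, T) with 1 = keep.
    # instruction_mask zeroes out instruction tokens so only the actual
    # sentence contributes to the embedding; no Python-level slicing,
    # so ONNX export can trace it with dynamic shapes.
    mask = (attention_mask * instruction_mask).unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts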

4 changes: 2 additions & 2 deletions evaluation/MTEB/examples/evaluate_model.py
@@ -3,7 +3,7 @@
 import logging
 import argparse
 from mteb import MTEB
-from InstructorEmbedding import INSTRUCTOR
+from InstructorEmbedding import Instructor
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser()
@@ -24,7 +24,7 @@
     # from functools import partialmethod
     #
     # tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
-    model = INSTRUCTOR(args.model_name,cache_folder=args.cache_dir)
+    model = Instructor(args.model_name,cache_folder=args.cache_dir)
     evaluation = MTEB(tasks=[args.task_name],task_langs=["en"])
     evaluation.run(model, output_folder=args.output_dir, eval_splits=[args.split],args=args,)
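Editor's note: the only change here is the class rename (INSTRUCTOR to Instructor); usage is otherwise unchanged. A minimal usage sketch, assuming one of the published checkpoints such as hkunlp/instructor-large; inputs are [instruction, sentence] pairs, matching the MTEB wrappers below:

from InstructorEmbedding import Instructor

model = Instructor('hkunlp/instructor-large')
embeddings = model.encode([
    ['Represent the Science sentence: ',
     'Parton distributions need to be calculated at NNLO.'],
])
print(embeddings.shape)  # e.g. (1, 768) for the large model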

33 changes: 15 additions & 18 deletions evaluation/MTEB/mteb/abstasks/AbsTaskRetrieval.py
@@ -597,7 +597,7 @@ def evaluate(
         model,
         split="test",
         batch_size=128,
-        corpus_chunk_size=None,
+        corpus_chunk_size=50000,
         target_devices=None,
         score_function="cos_sim",
         **kwargs
@@ -708,7 +708,7 @@ def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
         instruction = DEFINITIONS[self.args.prompt][self.args.task_name]['query']
         if self.args.prompt:
             for s in queries:
-                new_sentences.append([instruction, s, 0])
+                new_sentences.append([instruction, s])
         else:
             new_sentences = queries

@@ -717,7 +717,6 @@ def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
 
     def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
         self.count += 1
-        # print('count: ',self.count)
         if type(corpus) is dict:
             sentences = [
                 (corpus["title"][i] + ' ' + corpus["text"][i]).strip()
@@ -733,28 +732,26 @@ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs)
         new_sentences = []
         instruction = DEFINITIONS[self.args.prompt][self.args.task_name]['corpus']
         for s in sentences:
-            new_sentences.append([instruction, s, 0])
-        # kwargs['show_progress_bar'] = False
-        return self.model.encode(sentences, batch_size=128, **kwargs)
+            new_sentences.append([instruction, s])
+        return self.model.encode(new_sentences, batch_size=128, **kwargs)
 
     def encode_corpus_parallel(
         self, corpus: List[Dict[str, str]], pool: Dict[str, object], batch_size: int, chunk_id: int, **kwargs
     ):
+        sentences = []
         instruction = DEFINITIONS[self.args.prompt][self.args.task_name]['corpus']
         if type(corpus) is dict:
-            sentences = [
-                [instruction, (corpus["title"][i] + self.sep + corpus["text"][i]).strip()]
-                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
-                if "title" in corpus
-                else corpus["text"][i].strip()
-                for i in range(len(corpus["text"]))
-            ]
+            for i in range(len(corpus["text"])):
+                sentence = corpus["text"][i].strip()
+                if "title" in corpus:
+                    sentence = corpus["title"][i].strip() + self.sep + sentence
+                sentences.append([instruction, sentence])
         else:
-            sentences = [
-                [instruction, (doc["title"] + self.sep + doc["text"]).strip()]
-                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
-                for doc in corpus
-            ]
+            for doc in corpus:
+                sentence = doc["text"].strip()
+                if "title" in doc:
+                    sentence = doc["title"].strip() + self.sep + sentence
+                sentences.append([instruction, sentence])
 
         if chunk_id is not None and chunk_id >= len(pool["processes"]):
             output_queue = pool["output"]
2 changes: 2 additions & 0 deletions evaluation/MTEB/setup.py
@@ -84,6 +84,8 @@
         "torch",
         "tqdm",
         "rich",
+        "beir",
+        "evaluate==0.2.0"
     ],
     extras_require=extras,
     classifiers=[
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,3 +10,4 @@ sentence_transformers>=2.2.0
 torch
 tqdm
 rich
+tensorboard
49 changes: 14 additions & 35 deletions train.py
@@ -13,7 +13,7 @@
 
 import transformers
 from filelock import FileLock
-from InstructorEmbedding import INSTRUCTOR
+from InstructorEmbedding import Instructor, InstructorTransformer
 from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
@@ -27,6 +27,9 @@
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
+from transformers.training_args import TrainingArguments
+
 from transformers.utils import check_min_version, is_offline_mode
 from torch.utils.data import Dataset, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -100,7 +103,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             cur_inputs = {
                 'input_ids': inputs[f'{k}_input_ids'],
                 'attention_mask': inputs[f'{k}_attention_mask'],
-                'context_masks': inputs[f'{k}_context_masks'],
+                'instruction_mask': inputs[f'{k}_instruction_mask'],
             }
             cur_results[k] = model(cur_inputs)['sentence_embedding']
         embeddings_query = cur_results['query']
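Editor's note: this hunk only renames the mask key; the rest of compute_loss sits outside the diff. For orientation, a rough, self-contained sketch of this style of contrastive objective over the query/pos/neg embeddings built above. The temperature and the pos-vs-neg candidate set are illustrative; the repo's actual loss may differ, e.g. by scoring against all in-batch documents:

import torch
import torch.nn.functional as F

def contrastive_loss(q, pos, neg, temperature=0.01):
    # q, pos, neg: (B, D) sentence embeddings from the three encode passes.
    cands = torch.stack([pos, neg], dim=1)                     # (B, 2, D)
    sims = F.cosine_similarity(q.unsqueeze(1), cands, dim=-1)  # (B, 2)
    labels = torch.zeros(q.size(0), dtype=torch.long)          # positive = index 0
    return F.cross_entropy(sims / temperature, labels)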
@@ -156,7 +159,6 @@ class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
-
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
@@ -424,13 +426,8 @@ def main():
     )
 
     # Set seed before initializing model.
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        use_fast=model_args.use_fast_tokenizer,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
+    instructor_tokenizer = InstructorTransformer(model_name_or_path=model_args.model_name_or_path, load_model=False)
+    tokenizer = instructor_tokenizer.tokenizer  # pre-trained tokenizer
 
     set_seed(training_args.seed)
     with open(os.path.join(model_args.cache_dir, 'medi-data.json')) as f:
@@ -443,7 +440,7 @@
 
     real_batch_size = max(training_args.per_device_train_batch_size,
                           training_args.per_device_train_batch_size * torch.cuda.device_count())
-    # print('real_batch_size: ', real_batch_size,training_args.per_device_train_batch_size,torch.cuda.device_count())
+
     def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         examples_raw = []
         for idx in range(0, total_n, real_batch_size):
@@ -485,13 +482,11 @@ def get_dataset(examples_raw):
         for i in range(total_num):
             cur_e = examples_raw[i]
             for k in ['query','pos','neg']:
-                for s in cur_e[k][:-1]:
-                    assert not '!@#$%^&**!@#$%^&**' in s
                 cur_e[k][-1] = str(cur_e[k][-1])
                 if not data_args.add_prompt_to_document:
                     cur_e[k][0] = ''
                 assert cur_e[k][0].startswith('Represent ') or cur_e[k][0]==''
-                examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k]))
+                examples[k].append(cur_e[k])
                 if not cur_e['task_id'] in task_name_map:
                     task_name_map[cur_e['task_id']] = task_count
                     task_count += 1
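Editor's note: after this change each training field stays a structured [instruction, text] pair instead of a single string glued together with the '!@#$%^&**!@#$%^&**' sentinel. An illustrative, made-up MEDI-style example of the resulting shape (the instruction strings are hypothetical):

example = {
    'query': ['Represent the question for retrieving supporting documents: ',
              'what is the capital of France?'],
    'pos':   ['Represent the document for retrieval: ', 'Paris is the capital ...'],
    'neg':   ['Represent the document for retrieval: ', 'Lyon is a city ...'],
    'task_id': 0,
}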
@@ -500,36 +495,20 @@
 
     train_raw_datasets = DatasetDict({'train':Dataset.from_dict(get_dataset(train_examples_raw))})
 
-    model = INSTRUCTOR(real_name_or_path, cache_folder=model_args.cache_dir)
+    model = Instructor(real_name_or_path, cache_folder=model_args.cache_dir)
     column_names = train_raw_datasets["train"].column_names
 
     def preprocess_function(examples):
         all_tokenized = None
         for key in ['query','pos','neg']:
-            num = len(examples[key])
-            contexts = []
-            concatenated_input_texts = []
-            for local_idx in range(num):
-                splits = examples[key][local_idx].split('!@#$%^&**!@#$%^&**')
-                assert len(splits) == 2
-                contexts.append(splits[0])
-                concatenated_input_texts.append(''.join(splits))
-                assert isinstance(contexts[-1], str)
-                assert isinstance(concatenated_input_texts[-1], str)
-            tokenized = tokenizer(concatenated_input_texts,padding='max_length', truncation='longest_first', return_tensors="pt", max_length=data_args.max_source_length)
-            context_tok = tokenizer(contexts,padding='max_length', truncation='longest_first', return_tensors="pt", max_length=data_args.max_source_length)
-            tokenized['context_masks'] = torch.sum(context_tok['attention_mask'], dim=1)
-            tokenized['context_masks'] = tokenized['context_masks'] - 1
-            for my_idx in range(len(tokenized['context_masks'])):
-                if tokenized['context_masks'][my_idx] <= 1:
-                    tokenized['context_masks'][my_idx] = 0
-            keys = tokenized.keys()
+            input_features = instructor_tokenizer.tokenize(examples[key])
+            keys = input_features.keys()
             if all_tokenized is None:
-                all_tokenized = tokenized.copy()
+                all_tokenized = input_features.copy()
                 for k in keys:
                     all_tokenized[k] = all_tokenized[k].tolist()
             for k in keys:
-                all_tokenized[f'{key}_{k}'] = tokenized[k].tolist()
+                all_tokenized[f'{key}_{k}'] = input_features[k].tolist()
             all_tokenized['task_id'] = examples['task_id']
         return all_tokenized
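Editor's note: preprocess_function now delegates to instructor_tokenizer.tokenize, replacing the removed sentinel-splitting and double-tokenization dance, including the off-by-one bookkeeping on context_masks. Judging by the keys consumed in compute_loss (input_ids, attention_mask, instruction_mask), it plausibly behaves like the sketch below; this is an inferred reconstruction, not the repo's code:

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('hkunlp/instructor-large')

def tokenize_pairs(pairs, max_length=512):
    # pairs: list of [instruction, text]; tokenize the concatenation once,
    # then zero out the instruction positions in a separate mask.
    full = tok([i + t for i, t in pairs], padding='max_length',
               truncation=True, return_tensors='pt', max_length=max_length)
    inst = tok([i for i, _ in pairs], padding='max_length',
               truncation=True, return_tensors='pt', max_length=max_length)
    n_inst = inst['attention_mask'].sum(dim=1) - 1     # ignore the trailing </s>
    positions = torch.arange(max_length).unsqueeze(0)  # (1, T)
    # 1 for sentence tokens, 0 for instruction tokens.
    full['instruction_mask'] = (positions >= n_inst.unsqueeze(1)).long()
    return full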
