-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add paddle nv-embed-v1 #8785
Merged
sijunhe
merged 2 commits into
PaddlePaddle:develop
from
Li-Z-Q:add-paddle-nv-embed-mteb
Jul 28, 2024
Merged
add paddle nv-embed-v1 #8785
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,15 +15,28 @@ | |
import argparse | ||
import logging | ||
|
||
import mteb | ||
import paddle | ||
from evaluation.mteb.mteb_models_nv import NVEncodeModel | ||
from mteb import MTEB | ||
from mteb_models import EncodeModel | ||
|
||
from paddlenlp.transformers import AutoModel, AutoTokenizer | ||
from paddlenlp.peft import LoRAConfig, LoRAModel | ||
from paddlenlp.transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer | ||
|
||
|
||
def get_model(peft_model_name, base_model_name): | ||
if peft_model_name is not None: | ||
raise NotImplementedError("PEFT model is not supported yet") | ||
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, dtype="bfloat16") | ||
lora_config = LoRAConfig.from_pretrained(peft_model_name) | ||
lora_config.merge_weights = True | ||
lora_weights = paddle.load(peft_model_name + "/lora_model_state.pdparams") | ||
k = list(lora_weights.keys())[0] | ||
assert k.startswith( | ||
"llama." | ||
), f"You Must Manually Replace 'model' to 'llama'. Please Refer to do_replace_model_llama.py" | ||
hf_model = LoRAModel.from_pretrained(base_model, peft_model_name, lora_config=lora_config, dtype="bfloat16") | ||
return hf_model | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hf_model -> model There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
else: | ||
base_model = AutoModel.from_pretrained(base_model_name) | ||
return base_model | ||
|
@@ -67,39 +80,58 @@ def get_args(): | |
logging.basicConfig(level=logging.INFO) | ||
logger.info("Args: {}".format(args)) | ||
|
||
model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path) | ||
assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token" | ||
token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token, "pad_token": tokenizer.pad_token} | ||
tokenizer.pad_token = token_dict[args.pad_token] | ||
|
||
assert args.padding_side in [ | ||
"right", | ||
"left", | ||
], f"padding_side should be either 'right' or 'left', but got {args.padding_side}" | ||
assert not ( | ||
args.padding_side == "left" and args.pooling_method == "cls" | ||
), "Padding 'left' is not supported for pooling method 'cls'" | ||
tokenizer.padding_side = args.padding_side | ||
|
||
assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}" | ||
assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}" | ||
tokenizer.add_bos_token = bool(args.add_bos_token) | ||
tokenizer.add_eos_token = bool(args.add_eos_token) | ||
|
||
encode_model = EncodeModel( | ||
model=model, | ||
tokenizer=tokenizer, | ||
pooling_method=args.pooling_method, | ||
query_instruction=args.query_instruction, | ||
document_instruction=args.document_instruction, | ||
eval_batch_size=args.eval_batch_size, | ||
max_seq_length=args.max_seq_length, | ||
) | ||
if "NV-Embed" in args.base_model_name_or_path: | ||
logger.info("Using NV-Embed") | ||
|
||
query_prefix = "Instruct: " + args.query_instruction + "\nQuery: " | ||
passage_prefix = "" | ||
|
||
if args.task_name == "QuoraRetrieval": | ||
assert args.document_instruction != "document: ", f"QuoraRetrieval requires a document instruction" | ||
passage_prefix = "Instruct: " + args.document_instruction + "\nQuery: " # because this is STS task | ||
|
||
encode_model = NVEncodeModel.from_pretrained( | ||
args.base_model_name_or_path, | ||
tokenizer_path=args.base_model_name_or_path, | ||
eval_batch_size=args.eval_batch_size, | ||
query_instruction=query_prefix, | ||
document_instruction=passage_prefix, | ||
dtype="float16", | ||
) | ||
encode_model.eval() | ||
|
||
else: | ||
model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path) | ||
|
||
assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}" | ||
assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}" | ||
tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path) | ||
assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token" | ||
token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token} | ||
tokenizer.pad_token = token_dict[args.pad_token] | ||
assert args.padding_side in [ | ||
"right", | ||
"left", | ||
], f"padding_side should be either 'right' or 'left', but got {args.padding_side}" | ||
assert not ( | ||
args.padding_side == "left" and args.pooling_method == "cls" | ||
), "Padding 'left' is not supported for pooling method 'cls'" | ||
tokenizer.padding_side = args.padding_side | ||
tokenizer.add_bos_token = bool(args.add_bos_token) | ||
tokenizer.add_eos_token = bool(args.add_eos_token) | ||
|
||
encode_model = EncodeModel( | ||
model=model, | ||
tokenizer=tokenizer, | ||
pooling_method=args.pooling_method, | ||
query_instruction=args.query_instruction, | ||
document_instruction=args.document_instruction, | ||
eval_batch_size=args.eval_batch_size, | ||
max_seq_length=args.max_seq_length, | ||
) | ||
|
||
logger.info("Ready to eval") | ||
evaluation = MTEB(tasks=[args.task_name]) | ||
evaluation = MTEB(tasks=mteb.get_tasks(tasks=[args.task_name])) | ||
evaluation.run( | ||
encode_model, | ||
output_folder=f"{args.output_folder}/{args.task_name}/{args.pooling_method}", | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个NV-Embed-v1 是怎么得到的呢?从torch 转过来的吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,陆老师发您的文件就是从torch转过来的paddle版本的NV-Embed-v1模型权重