Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add paddle nv-embed-v1 #8785

Merged
merged 2 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 61 additions & 39 deletions legacy/pipelines/examples/contrastive_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

## 安装

推荐安装gpu版本的[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html),以cuda11.7的paddle为例,安装命令如下:
推荐安装 gpu 版本的[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html),以 cuda11.7的 paddle 为例,安装命令如下:

```
conda install nccl -c conda-forge
conda install paddlepaddle-gpu==2.6.1 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge
```
安装其他依赖:
```
pip install git+https://github.com/PaddlePaddle/PaddleNLP.git@develop
pip install -r requirements.txt
```

下载DuReader-Retrieval中文数据集
下载 DuReader-Retrieval 中文数据集

```
cd data
Expand Down Expand Up @@ -42,34 +43,34 @@ python train.py --do_train \
--use_matryoshka
```

- `model_name_or_path`: 选择预训练模型,可选rocketqa-zh-base-query-encoder
- `model_name_or_path`: 选择预训练模型,可选 rocketqa-zh-base-query-encoder
- `output_dir`: 模型保存路径
- `train_data`: 训练数据集路径,这里使用的是dureader中文数据集
- `overwrite_output_dir`: 是否覆盖模型保存路径,默认为False
- `fine_tune_type`: 训练模式,可选sft和lora, bitfit等策略
- `sentence_pooling_method`: 句子池化方法,可选cls和mean, cls为CLS层,mean为平均池化
- `train_data`: 训练数据集路径,这里使用的是 dureader 中文数据集
- `overwrite_output_dir`: 是否覆盖模型保存路径,默认为 False
- `fine_tune_type`: 训练模式,可选 sft 和 lora, bitfit 等策略
- `sentence_pooling_method`: 句子池化方法,可选 cls 和 mean, cls 为 CLS 层,mean 为平均池化
- `num_train_epochs`: 训练轮数
- `per_device_train_batch_size`: 单卡训练batch大小
- `per_device_train_batch_size`: 单卡训练 batch 大小
- `learning_rate`: 学习率
- `train_group_size`: 每个训练集正负样本的数据,默认为8,例如train_group_size=4,则每个训练集包含1个正样本和3个负样本
- `train_group_size`: 每个训练集正负样本的数据,默认为8,例如 train_group_size=4,则每个训练集包含1个正样本和3个负样本
- `max_example_num_per_dataset`: 每个训练集的最大样本数,默认为100000000
- `recompute`: 是否重新计算,默认为False
- `query_max_len`: query的最大长度,默认为32
- `query_instruction_for_retrieval`: query的检索指令,默认为None
- `passage_instruction_for_retrieval`: passage的检索指令,默认为None
- `passage_max_len`: passage的最大长度,默认为512
- `use_matryoshka`: 是否使用俄罗斯套娃策略(matryoshka),默认为False
- `recompute`: 是否重新计算,默认为 False
- `query_max_len`: query 的最大长度,默认为32
- `query_instruction_for_retrieval`: query 的检索指令,默认为 None
- `passage_instruction_for_retrieval`: passage 的检索指令,默认为 None
- `passage_max_len`: passage 的最大长度,默认为512
- `use_matryoshka`: 是否使用俄罗斯套娃策略(matryoshka),默认为 False
- `matryoshka_dims`: 俄罗斯套娃策略的维度,默认为[64, 128, 256, 512, 768]
- `matryoshka_loss_weights`: 俄罗斯套娃策略的损失权重,默认为[1, 1, 1, 1, 1]
- `use_inbatch_neg`: 是否使用in batch negatives策略,默认为False
- `use_flash_attention`: 是否使用flash attention,默认为False
- `temperature`: in batch negatives策略的temperature参数,默认为0.02
- `negatives_cross_device`: 跨设备in batch negatives策略,默认为False
- `margin`: in batch negatives策略的margin参数,默认为0.2
- `use_inbatch_neg`: 是否使用 in batch negatives 策略,默认为 False
- `use_flash_attention`: 是否使用 flash attention,默认为 False
- `temperature`: in batch negatives 策略的 temperature 参数,默认为0.02
- `negatives_cross_device`: 跨设备 in batch negatives 策略,默认为 False
- `margin`: in batch negatives 策略的 margin 参数,默认为0.2

### 多卡训练

单卡训练效率过低,batch_size较小,建议使用多卡训练,对于对比学习训练推荐使用大batch_size,多卡训练,示例命令如下:
单卡训练效率过低,batch_size 较小,建议使用多卡训练,对于对比学习训练推荐使用大 batch_size,多卡训练,示例命令如下:

```
python -m paddle.distributed.launch --gpus "1,2,3,4" train.py --do_train \
Expand Down Expand Up @@ -100,21 +101,42 @@ python evaluation/benchmarks.py --model_type bert \
--query_max_length 64 \
--passage_max_length 512 \
```
- `model_type`: 模型的类似,可选bert或roberta等等
- `query_model`: query向量模型的路径
- `passage_model`: passage向量模型的路径
- `query_max_length`: query的最大长度
- `passage_max_length`: passage的最大长度
- `evaluate_all`: 是否评估所有的checkpoint,默认为False,即只评估指定的checkpoint
- `model_type`: 模型的类似,可选 bert 或 roberta 等等
- `query_model`: query 向量模型的路径
- `passage_model`: passage 向量模型的路径
- `query_max_length`: query 的最大长度
- `passage_max_length`: passage 的最大长度
- `evaluate_all`: 是否评估所有的 checkpoint,默认为 False,即只评估指定的 checkpoint
- `checkpoint_dir`: 与`evaluate_all`一起使用


## MTEB评估
## MTEB 评估
[MTEB](https://github.com/embeddings-benchmark/mteb)
是一个大规模文本嵌入评测基准,包含了丰富的向量检索评估任务和数据集。
本仓库主要面向其中的中英文检索任务(Retrieval),并以SciFact数据集作为主要示例
本仓库主要面向其中的中英文检索任务(Retrieval),并以 SciFact 数据集作为主要示例

评估RepLLaMA向量检索模型([repllama-v1-7b-lora-passage](https://huggingface.co/castorini/repllama-v1-7b-lora-passage)):
评估 NV-Embed 向量检索模型([NV-Embed-v1](https://huggingface.co/nvidia/NV-Embed-v1)):
```
export CUDA_VISIBLE_DEVICES=0
python eval_mteb.py \
--base_model_name_or_path NV-Embed-v1 \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个NV-Embed-v1 是怎么得到的呢?从torch 转过来的吗?

Copy link
Contributor Author

@Li-Z-Q Li-Z-Q Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,陆老师发您的文件就是从torch转过来的paddle版本的NV-Embed-v1模型权重

--output_folder en_results/nv-embed-v1 \
--query_instruction "Given a claim, find documents that refute the claim" \
--task_name 'SciFact' \
--eval_batch_size 8
```
结果文件保存在`en_results/nv-embed-v1/SciFact/last/no_model_name_available/no_revision_available/SciFact.json`,包含以下类似的评估结果:
```
'ndcg_at_1': 0.67667,
'ndcg_at_3': 0.73826,
'ndcg_at_5': 0.76662,
'ndcg_at_10': 0.783,
'ndcg_at_20': 0.7936,
'ndcg_at_100': 0.80206,
'ndcg_at_1000': 0.80444
```

评估 RepLLaMA 向量检索模型([repllama-v1-7b-lora-passage](https://huggingface.co/castorini/repllama-v1-7b-lora-passage)):
```
export CUDA_VISIBLE_DEVICES=0
python evaluation/mteb/eval_mteb.py \
Expand Down Expand Up @@ -143,7 +165,7 @@ python evaluation/mteb/eval_mteb.py \
'ndcg_at_1000': 0.7794
```

评估BGE向量检索模型([bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)):
评估 BGE 向量检索模型([bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)):
```
export CUDA_VISIBLE_DEVICES=0
python evaluation/mteb/eval_mteb.py \
Expand Down Expand Up @@ -174,15 +196,15 @@ python evaluation/mteb/eval_mteb.py \
可支持配置的参数:
- `base_model_name_or_path`: 模型名称或路径
- `output_folder`: 结果文件存储路径
- `task_name`:任务(数据集)名称,如SciFact
- `task_split`:测试查询集合,如test或dev
- `query_instruction`:查询前添加的提示文本,如'query: '或None
- `document_instruction`:文档前添加的提示文本,如'passage: '或None
- `pooling_method`:获取表示的方式,last表示取最后token,mean表示取平均,cls表示取`[CLS]`token
- `task_name`:任务(数据集)名称,如 SciFact
- `task_split`:测试查询集合,如 test 或 dev
- `query_instruction`:查询前添加的提示文本,如'query: '或 None
- `document_instruction`:文档前添加的提示文本,如'passage: '或 None
- `pooling_method`:获取表示的方式,last 表示取最后 token,mean 表示取平均,cls 表示取`[CLS]`token
- `max_seq_length`: 最大序列长度
- `eval_batch_size`: 模型预测的批次大小(单个GPU
- `pad_token`:设置padding的token,可取unk_token、eos_token或pad_token
- `padding_side`:设置padding的位置,可取left或right
- `eval_batch_size`: 模型预测的批次大小(单个 GPU
- `pad_token`:设置 padding 的 token,可取 unk_token、eos_token 或 pad_token
- `padding_side`:设置 padding 的位置,可取 left 或 right
- `add_bos_token`:是否添加起始符,0表示不添加,1表示添加
- `add_eos_token`:是否添加结束符,0表示不添加,1表示添加

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,28 @@
import argparse
import logging

import mteb
import paddle
from evaluation.mteb.mteb_models_nv import NVEncodeModel
from mteb import MTEB
from mteb_models import EncodeModel

from paddlenlp.transformers import AutoModel, AutoTokenizer
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def get_model(peft_model_name, base_model_name):
if peft_model_name is not None:
raise NotImplementedError("PEFT model is not supported yet")
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, dtype="bfloat16")
lora_config = LoRAConfig.from_pretrained(peft_model_name)
lora_config.merge_weights = True
lora_weights = paddle.load(peft_model_name + "/lora_model_state.pdparams")
k = list(lora_weights.keys())[0]
assert k.startswith(
"llama."
), f"You Must Manually Replace 'model' to 'llama'. Please Refer to do_replace_model_llama.py"
model = LoRAModel.from_pretrained(base_model, peft_model_name, lora_config=lora_config, dtype="bfloat16")
return model
else:
base_model = AutoModel.from_pretrained(base_model_name)
return base_model
Expand Down Expand Up @@ -67,39 +80,58 @@ def get_args():
logging.basicConfig(level=logging.INFO)
logger.info("Args: {}".format(args))

model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)
assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token"
token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token, "pad_token": tokenizer.pad_token}
tokenizer.pad_token = token_dict[args.pad_token]

assert args.padding_side in [
"right",
"left",
], f"padding_side should be either 'right' or 'left', but got {args.padding_side}"
assert not (
args.padding_side == "left" and args.pooling_method == "cls"
), "Padding 'left' is not supported for pooling method 'cls'"
tokenizer.padding_side = args.padding_side

assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}"
assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}"
tokenizer.add_bos_token = bool(args.add_bos_token)
tokenizer.add_eos_token = bool(args.add_eos_token)

encode_model = EncodeModel(
model=model,
tokenizer=tokenizer,
pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
document_instruction=args.document_instruction,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
)
if "NV-Embed" in args.base_model_name_or_path:
logger.info("Using NV-Embed")

query_prefix = "Instruct: " + args.query_instruction + "\nQuery: "
passage_prefix = ""

if args.task_name == "QuoraRetrieval":
assert args.document_instruction != "document: ", f"QuoraRetrieval requires a document instruction"
passage_prefix = "Instruct: " + args.document_instruction + "\nQuery: " # because this is STS task

encode_model = NVEncodeModel.from_pretrained(
args.base_model_name_or_path,
tokenizer_path=args.base_model_name_or_path,
eval_batch_size=args.eval_batch_size,
query_instruction=query_prefix,
document_instruction=passage_prefix,
dtype="float16",
)
encode_model.eval()

else:
model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path)

assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}"
assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}"
tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)
assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token"
token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token}
tokenizer.pad_token = token_dict[args.pad_token]
assert args.padding_side in [
"right",
"left",
], f"padding_side should be either 'right' or 'left', but got {args.padding_side}"
assert not (
args.padding_side == "left" and args.pooling_method == "cls"
), "Padding 'left' is not supported for pooling method 'cls'"
tokenizer.padding_side = args.padding_side
tokenizer.add_bos_token = bool(args.add_bos_token)
tokenizer.add_eos_token = bool(args.add_eos_token)

encode_model = EncodeModel(
model=model,
tokenizer=tokenizer,
pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
document_instruction=args.document_instruction,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
)

logger.info("Ready to eval")
evaluation = MTEB(tasks=[args.task_name])
evaluation = MTEB(tasks=mteb.get_tasks(tasks=[args.task_name]))
evaluation.run(
encode_model,
output_folder=f"{args.output_folder}/{args.task_name}/{args.pooling_method}",
Expand Down
Loading
Loading