-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_linq_m7b.sh
86 lines (81 loc) · 2.79 KB
/
run_linq_m7b.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env bash
# Fine-tune Linq-Embed-Mistral (linq-mistral-7b-instruct) with LoRA using
# tevatron. Then use the trained adapter to encode the test corpus and the
# test queries, and write the top-50 retrieved passages for every query.
#
# Stages. `set -e` stops the script at the first failing stage, so a failed
# training run is never followed by encoding with a missing adapter.
#   1. LoRA training, DeepSpeed ZeRO-3 on GPUs 0,1   -> $output_dir
#   2. Corpus encoding, shard s on GPU s             -> $embedding_dir/corpus.<s>.pkl
#   3. Query encoding on GPU 1                       -> $embedding_dir/query-test.pkl
#   4. Retrieval, depth 50                           -> $ranking_file
set -euo pipefail

# ---------------------------------------------------------------- config ----
readonly data_root=/mnt/workspace/data/AIME/index/wjd/tevatron
readonly model_path=${data_root}/sent_embedding/sentenc_models/Linq-Embed-Mistral
readonly data_path=${data_root}/aqa_train_data_processed/train_qwen_0507.jsonl
readonly corpus_data=${data_root}/aqa_train_data_processed/test_corpus_data.jsonl
readonly query_data=${data_root}/aqa_train_data_processed/test_data_0606.tsv
readonly output_dir=retriever-mistral-linq
readonly embedding_dir=linq_embeddings
readonly ranking_file=results/linq_7b_top50.txt
# Corpus shard s is encoded on GPU s, so this must not exceed the GPU count.
readonly num_shards=2

# Inside double quotes, bash keeps "\n" as two literal characters (backslash
# and n), not a newline. The same string is used for training and for
# encoding, so the two stay consistent.
# NOTE(review): confirm whether a real newline was intended.
readonly query_prefix="Instruct: Given a question, retrieve Wikipedia passages that answer the question\nQuery: "

# Model/representation options shared by training and encoding. They are kept
# in one place so the encoder is configured exactly like the trained model.
shared_model_args=(
  --model_name_or_path "$model_path"
  --lora
  --query_prefix "$query_prefix"
  --passage_prefix ""
  --pooling eos
  --append_eos_token
  --normalize
)

# Options common to every encode run: trained adapter location and batch size.
encode_args=(
  "${shared_model_args[@]}"
  --output_dir=temp
  --lora_name_or_path "$output_dir"
  --per_device_eval_batch_size 32
)

# ------------------------------------------------------- 1. LoRA training ----
# Launcher options must precede the module name; everything after it is
# forwarded to tevatron's train driver.
deepspeed --include localhost:0,1 --master_port 52000 \
  --module tevatron.retriever.driver.train \
  --deepspeed deepspeed/ds_zero3_config.json \
  --output_dir "$output_dir" \
  --dataset_path "$data_path" \
  "${shared_model_args[@]}" \
  --lora_target_modules q_proj,v_proj \
  --save_steps 200 \
  --report_to none \
  --temperature 0.01 \
  --per_device_train_batch_size 2 \
  --gradient_checkpointing \
  --train_group_size 6 \
  --learning_rate 1e-4 \
  --query_max_len 64 \
  --passage_max_len 256 \
  --num_train_epochs 3 \
  --logging_steps 100 \
  --overwrite_output_dir \
  --gradient_accumulation_steps 4

# Create output directories. The search step cannot create results/ itself.
mkdir -p -- "$embedding_dir" "${ranking_file%/*}"

# ---------------------------------------------------- 2. corpus encoding ----
# Shards run one after another; each run is pinned to its own GPU.
for (( shard = 0; shard < num_shards; ++shard )); do
  CUDA_VISIBLE_DEVICES="$shard" python -m tevatron.retriever.driver.encode \
    "${encode_args[@]}" \
    --fp16 \
    --query_max_len 64 \
    --passage_max_len 256 \
    --dataset_path "$corpus_data" \
    --dataset_number_of_shards "$num_shards" \
    --dataset_shard_index "$shard" \
    --encode_output_path "${embedding_dir}/corpus.${shard}.pkl"
done

# ----------------------------------------------------- 3. query encoding ----
# NOTE(review): unlike the corpus runs, this one omits --fp16 and uses
# --query_max_len 512 (training used 64). Both are kept as before; confirm
# they are intentional.
CUDA_VISIBLE_DEVICES=1 python -m tevatron.retriever.driver.encode \
  "${encode_args[@]}" \
  --encode_is_query \
  --query_max_len 512 \
  --dataset_path "$query_data" \
  --encode_output_path "${embedding_dir}/query-test.pkl"

# ---------------------------------------------------- 4. retrieve top-50 ----
# The corpus pattern is quoted so bash passes it through unexpanded; the
# search driver matches the shard files itself.
python -m tevatron.retriever.driver.search \
  --query_reps "${embedding_dir}/query-test.pkl" \
  --passage_reps "${embedding_dir}/corpus*.pkl" \
  --depth 50 \
  --batch_size 128 \
  --save_text \
  --save_ranking_to "$ranking_file"