Commit 0a62dff

add bge_m3,flag_reranker,m3_base.
1 parent 12175f0 commit 0a62dff

8 files changed: +407 −0 lines changed

pytorch/Embedding/bge_m3/README.md

Lines changed: 5 additions & 0 deletions
@@ -16,3 +16,8 @@ pip install -r requirements.txt
 ```shell
 python test_bge_m3.py
 ```
+
+4. Performance
+```shell
+python perf_bge_m3_big.py
+```
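
For context: `test_bge_m3.py` itself is not part of this diff, but a minimal smoke test consistent with the API used in `perf_bge_m3_big.py` below might look like the following sketch. The local `./bge-m3` path and `musa:0` device string are taken from that script; the sample sentences are placeholders.

```python
# Minimal BGE-M3 smoke test (a sketch; not the committed test_bge_m3.py).
from FlagEmbedding import BGEM3FlagModel
import torch
import torch_musa  # registers the MUSA device backend with PyTorch
import numpy as np

model = BGEM3FlagModel('./bge-m3', use_fp16=True, device='musa:0')
sentences = ["What is BGE M3?", "BGE M3 is a multilingual embedding model."]
dense = model.encode(sentences, max_length=512)['dense_vecs']
print("similarity:", float(np.dot(dense[0], dense[1])))
```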
pytorch/Embedding/bge_m3/perf_bge_m3_big.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
from FlagEmbedding import BGEM3FlagModel
import torch
import torch_musa  # registers the MUSA device backend with PyTorch
import time
import numpy as np
import concurrent.futures
import random
import string
from tqdm import tqdm

# Generate long text of roughly 1024 tokens (about 1500-1800 characters)
def generate_long_text(target_tokens=1024):
    """Generate random text approximating the target token count (one word per token)."""
    words = []
    current_tokens = 0
    while current_tokens < target_tokens:
        word_len = random.randint(3, 10)
        words.append(''.join(random.choices(string.ascii_letters, k=word_len)))
        current_tokens += 1
    return " ".join(words)

def process_batch(model, batch_sentences, max_length):
    """Encode one batch and return the embeddings and the elapsed time."""
    start_time = time.time()
    embeddings = model.encode(batch_sentences, max_length=max_length)['dense_vecs']
    end_time = time.time()
    return embeddings, end_time - start_time

def warmup_model(model, batch_size=32, max_length=512, iterations=10):
    """Warm up the model to remove cold-start effects from the measurements."""
    warmup_sentences = ["Warmup sentence " * 20] * batch_size  # simulates long text
    for _ in range(iterations):
        model.encode(warmup_sentences, max_length=max_length)['dense_vecs']

if __name__ == '__main__':
    # Initialize the model
    model = BGEM3FlagModel('./bge-m3', use_fp16=True, device='musa:0')

    # ===== Key optimization 1: generate 1024-token long texts =====
    print("=== Generating 1024-token texts ===")
    long_query = generate_long_text(1024)
    long_passage = generate_long_text(1024)

    # ===== Key optimization 2: warm up with long texts =====
    print("\n=== Starting model warm-up with long texts ===")
    warmup_model(model, batch_size=32, max_length=1024, iterations=10)
    print("=== Warm-up completed ===\n")

    # ===== Prepare 30 batches for parallel execution =====
    batch_pairs_list = []
    for _ in range(30):
        batch = []
        for _ in range(32):  # 32 query/passage pairs per batch (64 texts)
            q = generate_long_text(1024) if random.random() > 0.5 else long_query
            p = generate_long_text(1024) if random.random() > 0.5 else long_passage
            batch.append(q)
            batch.append(p)
        batch_pairs_list.append(batch)

    # ===== Run the 30 batches in parallel =====
    print("=== Starting parallel processing of 30 batches ===")
    total_tokens = 0
    for batch in batch_pairs_list:
        total_tokens += sum(len(text.split()) for text in batch)  # word count as a rough token proxy

    start_time = time.time()
    batch_results = []
    batch_times = []

    # Process the batches concurrently with a thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(process_batch, model, batch, 1024) for batch in batch_pairs_list]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            embeddings, batch_time = future.result()
            batch_results.append(embeddings)
            batch_times.append(batch_time)

    end_time = time.time()
    total_time = end_time - start_time

    # ===== Compute performance metrics =====
    # 1. Throughput
    throughput_batches = len(batch_results) / total_time
    throughput_tokens = total_tokens / total_time

    # 2. Latency
    avg_batch_time = sum(batch_times) / len(batch_times)
    max_batch_time = max(batch_times)
    min_batch_time = min(batch_times)

    # ===== Performance report =====
    print("\n===== Performance Report =====")
    print(f"Total batches processed: {len(batch_results)}")
    print(f"Total tokens processed: {total_tokens}")
    print(f"Total processing time: {total_time:.2f} seconds")
    print("\n--- Throughput ---")
    print(f"Throughput (batches/sec): {throughput_batches:.2f}")
    print(f"Throughput (tokens/sec): {throughput_tokens:.2f}")
    print("\n--- Latency ---")
    print(f"Avg batch time: {avg_batch_time:.4f} sec")
    print(f"Max batch time: {max_batch_time:.4f} sec")
    print(f"Min batch time: {min_batch_time:.4f} sec")
    print("=============================")

    # Sample similarity computation: each batch interleaves query/passage,
    # so [0:2] is the first (query, passage) pair and [2:4] the second.
    embeddings_1 = batch_results[0][0:2]
    embeddings_2 = batch_results[0][2:4]
    similarity = np.dot(embeddings_1, embeddings_2.T)
    print(f"\nSample similarity matrix:\n{similarity}")
pytorch/Embedding/flag_reranker/README.md

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
0. Start docker
For the startup command, see [README.md](../../README.md)

1. Prerequisites
```shell
pip install -r requirements.txt

pip install -U huggingface_hub
```
2. Export environment variables
```shell
export HF_ENDPOINT=https://hf-mirror.com
```

3. Test
```shell
python perf_flag_reranker.py
```
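
With `HF_ENDPOINT` pointed at the mirror, the reranker weights can be fetched up front rather than on first use. A hedged sketch using `huggingface_hub` (the `local_dir` value is an arbitrary choice; the model id comes from `perf_flag_reranker.py` below):

```python
# Pre-download the reranker weights through the configured HF_ENDPOINT mirror.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="BAAI/bge-reranker-large", local_dir="./bge-reranker-large")
```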
pytorch/Embedding/flag_reranker/perf_flag_reranker.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
from FlagEmbedding import FlagReranker
import torch
import torch_musa  # registers the MUSA device backend with PyTorch
import time
import numpy as np
import concurrent.futures
import random
import string

# Generate long text of roughly 1024 tokens (about 1500-1800 characters)
def generate_long_text(target_tokens=1024):
    words = []
    current_tokens = 0
    while current_tokens < target_tokens:
        word_len = random.randint(3, 10)
        words.append(''.join(random.choices(string.ascii_letters, k=word_len)))
        current_tokens += 1
    return " ".join(words)

def process_batch(reranker, batch_pairs):
    """Score one batch of pairs and return the scores and the elapsed time."""
    start_time = time.perf_counter()
    scores = reranker.compute_score(batch_pairs)
    end_time = time.perf_counter()
    return scores, end_time - start_time

def main():
    # Load the model (FP16 precision on the MUSA device)
    reranker = FlagReranker('BAAI/bge-reranker-large',
                            use_fp16=True,
                            device="musa")

    # ===== Long-text optimization: generate 1024-token inputs =====
    print("=== Generating 1024-token texts ===")
    long_query = generate_long_text(1024)
    long_passage = generate_long_text(1024)

    # ===== Key optimization: warm up the model with long texts =====
    print("=== Starting model warm-up with long texts ===")
    warmup_pairs = [[long_query, long_passage]] * 16
    for _ in range(5):
        reranker.compute_score(warmup_pairs)
    print("=== Warm-up completed ===\n")

    # Single long-text inference test
    start_time = time.perf_counter()
    score = reranker.compute_score([long_query, long_passage])
    latency = (time.perf_counter() - start_time) * 1000
    print(f"Long text score: {score} | Latency: {latency:.2f} ms")

    # Prepare batched data (30 parallel tasks)
    batch_pairs_list = []
    for _ in range(30):
        pairs = []
        for _ in range(64):  # 64 query/passage pairs per task
            q = generate_long_text(1024) if random.random() > 0.5 else long_query
            p = generate_long_text(1024) if random.random() > 0.5 else long_passage
            pairs.append([q, p])
        batch_pairs_list.append(pairs)

    # ===== Run the 30 tasks in parallel =====
    print("\n=== Starting parallel processing of 30 batches ===")
    total_tokens = sum(
        sum(len(q.split()) + len(p.split()) for q, p in pairs)
        for pairs in batch_pairs_list
    )
    batch_times = []
    start_time = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:  # cap the concurrency
        futures = [executor.submit(process_batch, reranker, pairs) for pairs in batch_pairs_list]

        batch_results = []
        for future in concurrent.futures.as_completed(futures):
            scores, batch_time = future.result()
            batch_results.append(scores)
            batch_times.append(batch_time)

    total_time = time.perf_counter() - start_time

    # Performance statistics
    total_pairs = 30 * 64  # 30 tasks * 64 pairs per task
    throughput_pairs = total_pairs / total_time
    throughput_tokens = total_tokens / total_time

    avg_batch_time = sum(batch_times) / len(batch_times)
    max_batch_time = max(batch_times)
    min_batch_time = min(batch_times)

    print("\n===== Performance Report =====")
    print(f"Total batches processed: {len(batch_results)}")
    print(f"Total pairs processed: {total_pairs}")
    print(f"Total tokens processed: {total_tokens}")
    print(f"Total processing time: {total_time:.2f} seconds")

    print("\n--- Throughput ---")
    print(f"Throughput: {throughput_pairs:.2f} pairs/sec")
    print(f"Token throughput: {throughput_tokens:.2f} tokens/sec")

    print("\n--- Latency ---")
    print(f"Average batch time: {avg_batch_time:.4f} sec")
    print(f"Max batch time: {max_batch_time:.4f} sec")
    print(f"Min batch time: {min_batch_time:.4f} sec")

    print("=============================")


if __name__ == "__main__":
    main()
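
The scores returned by `FlagReranker.compute_score` are raw relevance logits and are unbounded. If a 0-1 score is easier to threshold downstream, a sigmoid can be applied after the fact; a small sketch (the 0.5 cutoff mentioned in the comment is an arbitrary example, not part of this commit):

```python
import numpy as np

def to_probabilities(scores):
    """Map raw reranker logits to (0, 1) with a sigmoid."""
    return 1.0 / (1.0 + np.exp(-np.asarray(scores, dtype=np.float64)))

# e.g. keep pairs whose sigmoid score clears an application-chosen threshold such as 0.5
```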
pytorch/Embedding/flag_reranker/requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
FlagEmbedding
accelerate==1.0.1
transformers==4.44.0
peft
pytorch/Embedding/m3_base/README.md

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
0. Start docker
For the startup command, see [README.md](../../README.md)

1. Prerequisites
```shell
pip install -r requirements.txt

pip install -U huggingface_hub
```
2. Export environment variables
```shell
export HF_ENDPOINT=https://hf-mirror.com
```

3. Test
```shell
python perf_m3_base.py
```
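
`perf_m3_base.py` is referenced above but its diff is not shown in this view. If it loads a local `./bge-m3` checkout the way `perf_bge_m3_big.py` does (an assumption), the weights can be fetched through the mirror ahead of time; a hedged sketch:

```python
# Assumes HF_ENDPOINT is already exported; fetches BGE-M3 into a local directory.
# The local_dir value mirrors the path used by perf_bge_m3_big.py and is an assumption.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="BAAI/bge-m3", local_dir="./bge-m3")
```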
