|
| 1 | +from FlagEmbedding import BGEM3FlagModel |
| 2 | +import torch |
| 3 | +import torch_musa |
| 4 | +import time |
| 5 | +import numpy as np |
| 6 | +import concurrent.futures |
| 7 | +import random |
| 8 | +import string |
| 9 | +from tqdm import tqdm |
| 10 | + |
# Build a random text made of `target_tokens` whitespace-separated words.
def generate_long_text(target_tokens=1024):
    """Return random ASCII-letter words joined by single spaces.

    Each generated word is counted as one "token", so ``target_tokens`` is
    really a whitespace word count.
    NOTE(review): a subword tokenizer will likely split these random words
    into several tokens each, so the tokenized length is probably larger
    than ``target_tokens`` -- confirm against the model's tokenizer.

    Args:
        target_tokens: Number of words to generate; values <= 0 yield "".

    Returns:
        A string of words, each 3 to 10 letters long, separated by one space.
    """
    pieces = []
    # The length of the list doubles as the running word count.
    while len(pieces) < target_tokens:
        length = random.randint(3, 10)
        letters = random.choices(string.ascii_letters, k=length)
        pieces.append(''.join(letters))
    return " ".join(pieces)
| 21 | + |
def process_batch(model, batch_sentences, max_length):
    """Encode one batch and return its dense vectors plus the elapsed time.

    Args:
        model: Object exposing ``encode(sentences, max_length=...)`` that
            returns a mapping containing the key ``'dense_vecs'``.
        batch_sentences: The texts to encode in this call.
        max_length: Forwarded unchanged to ``model.encode``.

    Returns:
        A ``(embeddings, elapsed_seconds)`` tuple, where ``embeddings`` is the
        ``'dense_vecs'`` entry of the encode result and ``elapsed_seconds`` is
        the duration of the encode call as a float.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump (even backwards) if the system clock is adjusted.
    started = time.perf_counter()
    embeddings = model.encode(batch_sentences, max_length=max_length)['dense_vecs']
    elapsed = time.perf_counter() - started
    return embeddings, elapsed
| 28 | + |
def warmup_model(model, batch_size=32, max_length=512, iterations=10):
    """Run throw-away encode passes so later timings skip cold-start cost.

    The warm-up batch is ``batch_size`` copies of the phrase
    "Warmup sentence " repeated 20 times; the encode results are discarded.

    Args:
        model: Object exposing ``encode(sentences, max_length=...)`` that
            returns a mapping containing the key ``'dense_vecs'``.
        batch_size: Number of identical sentences in the warm-up batch.
        max_length: Forwarded unchanged to ``model.encode``.
        iterations: How many times the batch is encoded.
    """
    sample = "Warmup sentence " * 20  # stand-in for a longer text
    warmup_batch = [sample] * batch_size
    for _step in range(iterations):
        outputs = model.encode(warmup_batch, max_length=max_length)
        # Same key lookup as a real run; the vectors themselves are unused.
        _discarded = outputs['dense_vecs']
| 34 | + |
if __name__ == '__main__':
    # Benchmark script: encode 30 batches of long random texts through a
    # thread pool and report throughput and per-batch latency.
    NUM_BATCHES = 30
    PAIRS_PER_BATCH = 32   # each pair adds one query and one passage
    MAX_LENGTH = 1024
    MAX_WORKERS = 8

    # Load the model in fp16 on the first MUSA device.
    model = BGEM3FlagModel('./bge-m3', use_fp16=True, device='musa:0')

    # ===== Step 1: build the shared long texts =====
    # NOTE: generate_long_text counts whitespace-separated words, so every
    # "token" figure in this script is a word count, not a tokenizer count.
    print("=== Generating 1024-token texts ===")
    long_query = generate_long_text(MAX_LENGTH)
    long_passage = generate_long_text(MAX_LENGTH)

    # ===== Step 2: warm up before timing =====
    print("\n=== Starting model warm-up with long texts ===")
    warmup_model(model, batch_size=32, max_length=MAX_LENGTH, iterations=10)
    print("=== Warm-up completed ===\n")

    # ===== Step 3: prepare the batches =====
    # Every batch is interleaved as [q0, p0, q1, p1, ...] and therefore holds
    # 2 * PAIRS_PER_BATCH texts. Each slot is either a fresh random text or
    # the shared long_query / long_passage, chosen with probability 0.5.
    batch_pairs_list = []
    for _ in range(NUM_BATCHES):
        batch = []
        for _ in range(PAIRS_PER_BATCH):
            q = generate_long_text(MAX_LENGTH) if random.random() > 0.5 else long_query
            p = generate_long_text(MAX_LENGTH) if random.random() > 0.5 else long_passage
            batch.append(q)
            batch.append(p)
        batch_pairs_list.append(batch)

    # ===== Step 4: run all batches through the thread pool =====
    print("=== Starting 30 parallel batch processing ===")
    total_tokens = sum(
        len(text.split()) for batch in batch_pairs_list for text in batch
    )

    # Pre-sized so results can be stored by submission index.
    batch_results = [None] * len(batch_pairs_list)
    batch_times = [None] * len(batch_pairs_list)

    # perf_counter is monotonic, unlike the wall clock returned by time.time().
    start_time = time.perf_counter()

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_index = {
            executor.submit(process_batch, model, batch, MAX_LENGTH): index
            for index, batch in enumerate(batch_pairs_list)
        }

        # as_completed yields futures in completion order; writing by index
        # keeps batch_results[i] aligned with batch_pairs_list[i].
        for future in tqdm(concurrent.futures.as_completed(future_to_index),
                           total=len(future_to_index)):
            index = future_to_index[future]
            batch_results[index], batch_times[index] = future.result()

    total_time = time.perf_counter() - start_time

    # ===== Step 5: compute metrics =====
    # 1. Throughput
    throughput_batches = len(batch_results) / total_time
    throughput_tokens = total_tokens / total_time

    # 2. Latency. Each value is the duration of one encode call measured
    #    inside its worker thread while other workers may be running.
    avg_batch_time = sum(batch_times) / len(batch_times)
    max_batch_time = max(batch_times)
    min_batch_time = min(batch_times)

    # ===== Step 6: report =====
    print("\n===== Performance Report =====")
    print(f"Total batches processed: {len(batch_results)}")
    print(f"Total tokens processed: {total_tokens}")
    print(f"Total processing time: {total_time:.2f} seconds")
    print("\n--- Throughput ---")
    print(f"Throughput (batches/sec): {throughput_batches:.2f}")
    print(f"Throughput (tokens/sec): {throughput_tokens:.2f}")
    print("\n--- Latency ---")
    print(f"Avg batch time: {avg_batch_time:.4f} sec")
    print(f"Max batch time: {max_batch_time:.4f} sec")
    print(f"Min batch time: {min_batch_time:.4f} sec")
    print("=============================")

    # ===== Sample similarity =====
    # Queries sit at even positions and passages at odd positions of a batch,
    # so stride slices pick the first two of each from the first batch.
    first_batch = batch_results[0]
    embeddings_1 = first_batch[0:4:2]  # first two queries
    embeddings_2 = first_batch[1:4:2]  # first two passages
    similarity = np.dot(embeddings_1, embeddings_2.T)
    print(f"\nSample similarity matrix:\n{similarity}")
0 commit comments