Skip to content

xiaokongkong/kuaisearch_learning_notes

Repository files navigation

KuaiSearch · 复现与工程化笔记

基于官方 KuaiSearch baseline(commit d38a080f)做的一次完整复现,配套工程化改造。保留原始论文的三大模块(Recall / Relevance / Ranking),在模型下载、多卡调度、脚本易用性、可复现性上做了加固,并在 KuaiSearch-Lite 子集上跑出了可对照的实验结果。

📄 原始官方 README 见 README_ori.md(保留了 baseline 仓库未改动的版本,便于对照看改了什么)。

官方项目:

模块详细笔记(逻辑、代码、实验):

  • 📘 notes/recall.md —— Recall(BM25 / DocT5Query / DPR / GR)数据处理、模型、实验记录
  • 📘 notes/relevance.md —— Relevance(Cross-Encoder / Bi-Encoder / GR)流程与指标
  • 📘 notes/ranking.md —— Ranking(DCNv1/v2 / DNN / WideDeep / DIN)特征装配与训练曲线

目录


改进一览

相对 baseline d38a080f,主要变更如下(共 48 个 commit,40 个文件,+734/-101 行):

1. 统一模型下载机制 common/model_resolver.py

  • 新增 ensure_model_downloaded(...),封装了 ModelScope 优先 / HuggingFace 兜底 的双源下载逻辑,任一源失败会自动切换。
  • 支持通过环境变量 MODEL_DOWNLOAD_PRIORITY=modelscope|huggingface 切换优先级。
  • 覆盖所有需要预训练模型的入口:
    • recall/BM25/doc2query.py(mT5-base)
    • recall/GR/train.py(mT5-base)
    • recall/dpr/train.py(bert-base-chinese)
    • recall/data/title_embedding_prepare.py(bge-small-zh-v1.5)
    • relevance/crossencoder/train.pyrelevance/embedding/train.pyrelevance/GR/train.py
    • ranking/data/process.py(本次也加了 fallback 逻辑)
  • 本地已有模型则直接复用,force_download=True 可强制重拉。

2. Shell 脚本加固 scripts/*.sh

为所有训练/评估脚本加了统一头部:

set -euo pipefail                # 出错即停
cd "${PROJECT_ROOT}"             # 始终以项目根为工作目录
export PYTHONPATH=...            # 让 common/ 可被 import
export WANDB_DISABLED=true       # 禁用 wandb 上报(因为我没有key,所以禁用了;需要的话可以再打开)
export WANDB_MODE=disabled
ulimit -n 65535                  # 放开 fd 限制

对多卡任务(recall_gr.shrecall_doc2query.shrelevance_embedding.shrelevance_gr.sh):

  • 把写死的 --nproc_per_node=8 改成可配置的 GPU 列表:默认 CUDA_VISIBLE_DEVICES=1,2,3,4,5,6NPROC_PER_NODE 自动按列表长度推断。
  • 通过 DOC2QUERY_GPUS / RECALL_GR_GPUS / RELEVANCE_EMBEDDING_GPUS / RELEVANCE_GR_GPUS 环境变量覆盖。
  • 新增 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,缓解长训练显存碎片问题。

3. Recall 模块改动

  • recall/BM25/config.jsonper_device_train_batch_size: 16 → 8gradient_accumulation_steps: 1 → 2,新增 optim: adafactor(mT5 显存占用大,改 Adafactor 更稳)。
  • recall/BM25/parameters.pyrecall/GR/train.pyrecall/dpr/train.py:新增 --download_priority 参数,report_to=[] 禁用 wandb。
  • recall/BM25/doc2query.pydataloader_num_workers: 8 → 2(减少 DataLoader 下子进程 OOM 风险),tokenizer 显式 legacy=True 规避 T5Tokenizer 警告。
  • recall/dpr/train.py:训练与推理强制 local_files_only=True,加 HF_HUB_OFFLINE=1TRANSFORMERS_OFFLINE=1,避免下载完又去联网校验。
  • recall/dpr/test.pyencoding_batch_size: 512 → 2048embedding_dir 默认路径调整为 ./doc_embeddingpretrained_model_path 调整为 ./model_single/best/doc_encoder
  • recall/data/construct_i2q.py:修正输入/输出相对路径(train_dpr.jsonrecall/data/train_dpr.json)。
  • recall/data/faiss_kmeans_quantize.py
    • 修复了从任意工作目录运行时的路径问题(embeddings/...recall/data/embeddings/...)。
    • 加了中文注释解释 FAISS 多层残差量化的每一步(聚类 → 残差 → 再聚类)。
    • 修正 code_check() 末尾保存路径 item_code.ptrecall/data/embeddings/item_code.pt
  • recall/data/process_pipeline.py:子进程指定 cwd=PROJECT_ROOT 并注入 PYTHONPATH,6 步流水线在任意目录启动都能跑通。
  • recall/data/title_embedding_prepare.py:改成 try/except fallback 模式——先直连加载,失败则通过 ensure_model_downloaded 从 ModelScope 拉下来再加载;同时修复 CPU-only 环境下的 device 兼容问题。
  • recall/GR/ 增加 code → item_id 反查 + item-level 指标:原 baseline 的 GR 评估只在「code 字符串空间」算 Recall@K,无法和 BM25/DPR/DocT5Query 的 item-level Recall 直接比较。改动:
    • recall/data/process4gr.py:在生成 item_text_codes.json 的同时额外输出 recall/data/code2itemids.json,提供两种粒度的反查表:full(key 为 "c0 c1 c2 dup")和 prefix3(key 为 "c0 c1 c2",用于消歧 dup 不准时的 fallback);测试集 JSON 每条记录新增 qiditem_ids 字段(与 output code 列表对齐)。
    • recall/GR/dataset.py:test 模式下 __getitem__item_ids 一并透传到 collator,避免按全局索引回查。
    • recall/GR/train.pyevaluate():新增 _load_code2itemids / _codes_to_itemids——把 50 个 beam 的 code 反查为 item_id 列表(rank-stable 去重,full 缺失时回退 3-prefix),再算 item_Recall@{10,20,50}、item_Hits@{10,20,50}、item_MRR@{10,20,50};同时保留 code-level 指标(code_Hits@Kcode_Recall@K)作为对照。
    • 新增 recall/GR/evaluate.py + scripts/recall_gr_eval.sh:可对已训练好的 DSI/best_model 单独重跑评估,无需重训,输出 metrics 到 DSI/best_model/test_metrics_item.json

4. Relevance 模块改动

  • relevance/data/process.py正/负样本打标规则修改 —— 原先直接用 label 字段;改为根据 score 字段,score == 3 为正样本(label=1),否则负样本(label=0)。匹配论文中"高相关性 = score 3"的定义。
  • relevance/*/parameters.py:新增 download_priority 字段,report_to=[] 禁用 wandb。
  • relevance/embedding/parameters.pySentenceTransformerTrainingArguments 在新版 sentence-transformers 里位置有变,加了 try/except 兼容两种导入路径。
  • relevance/{crossencoder,embedding,GR}/train.py:在模型加载前统一调用 ensure_model_downloaded(...),保证离线环境可用。

5. Ranking 模块改动

  • ranking/data/process.py
    • 加入 ensure_model_downloaded fallback,模型加载方式与 recall 保持一致。
    • 修复 JSON 读取的鲁棒性:json.loads 外套 try/except JSONDecodeError,遇到数据损坏行打印警告并跳过;字段改用 obj.get(...) 避免 KeyError
  • ranking/datasets.py
    • 使用 obj.get("split", "train") 兼容缺失 split 字段的行;统计并打印最终的 train / valid / test 样本数。
  • ranking/main.py:对 trainer.evaluate_test() 的返回值做 None 守护。当 test 集为空时打印 [TEST RESULT] skipped (no test set) 而不是直接在 f-string 里格式化 None(原 baseline 在这里会抛 TypeError: unsupported format string passed to NoneType.__format__)。
  • scripts/ranking_train.sh:显式 CUDA_VISIBLE_DEVICES=0,避免 accelerate 默认抢到错的卡。

6. 依赖升级 requirements.txt

整个 requirements.txt 已经用 pip freeze 刷成当前 env 的精确版本(Python 3.9.25,PyTorch 2.6.0 + CUDA 12.4,transformers 4.57.6,accelerate 1.10.1,sentence-transformers 5.1.2,modelscope 1.35.4,huggingface_hub 0.36.2 等),可以直接 pip install -r requirements.txt 复现,细节见 环境与安装

7. 其它

  • common/__init__.py:新建,方便包导入。
  • demo/items.jsonldemo/corpus.jsonl:与主流程命名对齐。
  • static/js/main.js:小修。
  • 新增多个 shell 脚本 wrapper(ranking_data_process.shranking_train.shrecall_bm25_eval.sh 等)。

环境与安装

# 推荐用 conda / mamba
conda create -n llm python=3.9 -y
conda activate llm
pip install -r requirements.txt

本次复现实际使用的版本

  • Python 3.9.25mambaforge env)
  • PyTorch 2.6.0 + CUDA 12.4(nvidia-*-cu12 系列 12.4.127)
  • transformers 4.57.6accelerate 1.10.1peft 0.17.1
  • sentence-transformers 5.1.2faiss-cpu 1.8.0
  • modelscope 1.35.4huggingface_hub 0.36.2
  • datasets 4.0.0numpy 1.26.4pandas 2.3.3scikit-learn 1.6.1

硬件

  • 多卡训练(recall / relevance)默认用 GPU 1,2,3,4,5,6(6 张);单卡环境请设 CUDA_VISIBLE_DEVICES=0
  • Ranking(DCNv1)单卡就能跑完(CUDA_VISIBLE_DEVICES=0)。

PyTorch ≥ 2.6 是必须的——新版 transformers 加载 .bin 权重时会校验 PyTorch 版本(CVE-2025-32434),见踩坑记录 第 1 条。


模型下载策略

训练脚本会在首次运行时自动下载所需的预训练模型到 ./model/ 下,目录结构示例:

model/
├── BAAI/bge-small-zh-v1.5        # title embedding & ranking encoder
├── google/mt5-base               # doc2query / GR
├── google-bert/bert-base-chinese # DPR
├── AI-ModelScope/xlm-roberta-base # relevance embedding
└── LLM-Research/Llama-3.2-3B-Instruct # relevance GR

切换下载源:

export MODEL_DOWNLOAD_PRIORITY=modelscope   # 默认
export MODEL_DOWNLOAD_PRIORITY=huggingface

命令行也可在对应训练脚本里传 --download_priority modelscope|huggingface 单独覆盖。


数据准备

pip install huggingface_hub
python -c "
from huggingface_hub import snapshot_download
snapshot_download(repo_id='benchen4395/KuaiSearch', repo_type='dataset', local_dir='./data')
"

⚠️ 用的是 KuaiSearch-Lite 子集:本次所有实验(recall / relevance / ranking)跑的都是 HuggingFace 上 benchen4395/KuaiSearch repo 里的 Lite 版本(item 数 6,634,118、session 数 161,740、用户数 102,086,体量约为完整版的 1/10)。完整版(KuaiSearch-Full)本地没跑,所以本仓库给出的指标只在 Lite 上有意义,不能直接和论文正式指标对比。

⚠️ 文件名兼容:下载下来的商品库文件名是 data/item.jsonl,但 baseline 代码和本仓库所有脚本(recall/data/build_dpr.pyranking/data/process.pyranking/datasets.py 等)默认读的是 data/corpus.jsonl。跑通之前务必先重命名:

mv data/item.jsonl data/corpus.jsonl
# 或者软链接保留原文件
ln -s item.jsonl data/corpus.jsonl

其它 jsonl(rank.jsonl / relevance.jsonl / users.jsonl / train.queries.tsv / train.qrels.tsv / test.queries.tsv / test.qrels.tsv)文件名和代码约定一致,不需要改动。


使用方式

所有脚本都必须从项目根目录(本仓库的 kuaisearch_learning_notes/)运行。新版脚本已经内置 cd 到项目根,所以从哪里调都行。

日志统一写到 logs/(首次运行请先 mkdir -p logs)。

Recall

# 1. 数据预处理(6 步流水线)
nohup bash scripts/recall_data_process.sh > logs/recall_data_process.txt 2>&1 &

# 2a. BM25
nohup bash scripts/recall_bm25_eval.sh    > logs/recall_bm25.txt 2>&1 &

# 2b. DocT5Query(先训 mT5 伪 query 生成,再按 BM25 评估扩展语料)
nohup bash scripts/recall_doc2query.sh        > logs/recall_doc2query.txt 2>&1 &
nohup bash scripts/recall_docT5query_eval.sh  > logs/recall_docT5query.txt 2>&1 &

# 2c. DPR(BERT 双塔稠密检索)
nohup bash scripts/recall_dpr.sh > logs/recall_dpr.txt 2>&1 &

# 2d. Generative Retrieval(基于 mT5 的 DSI)
nohup bash scripts/recall_gr.sh > logs/recall_gr.txt 2>&1 &

# 2e. (可选)GR item-level 评估:对已训练好的 best_model 重跑指标,
#      会用 recall/data/code2itemids.json 把生成的 50 个 beam code
#      反查成 item_id 列表,再算 Recall@K / Hits@K / MRR@K(与 BM25/DPR 同口径)。
nohup bash scripts/recall_gr_eval.sh > logs/recall_gr_eval.txt 2>&1 &

可自定义 GPU:

RECALL_GR_GPUS="0,1,2,3" bash scripts/recall_gr.sh

Relevance

# 1. 数据预处理(score==3 判正)
nohup bash scripts/relevance_data_process.sh > logs/relevance_data_process.txt 2>&1 &

# 2a. Cross-Encoder
nohup bash scripts/relevance_crossencoder.sh > logs/relevance_crossencoder.txt 2>&1 &

# 2b. Bi-Encoder / Embedding(XLM-RoBERTa)
nohup bash scripts/relevance_embedding.sh > logs/relevance_embedding.txt 2>&1 &

# 2c. Generative Relevance(Llama-3.2-3B)
nohup bash scripts/relevance_gr.sh > logs/relevance_gr.txt 2>&1 &

Ranking

# 1. 特征/向量预处理(会用 bge-small-zh-v1.5 做 query/item 编码)
nohup bash scripts/ranking_data_process.sh > logs/ranking_data_process.txt 2>&1 &

# 2. 训练(默认 DCNv1,单卡)
nohup bash scripts/ranking_train.sh > logs/ranking_train.txt 2>&1 &

实验结果(本次复现)

以下数字均取自 logs/,均为 KuaiSearch-Lite 子集上的本地复现,配置为 6 张 GPU(CUDA_VISIBLE_DEVICES=1..6)。

Recall

方法 recall@10 mrr@10 recall@20 mrr@20 recall@50 mrr@50 recall@100 mrr@100
BM25 (logs/recall_bm25.txt) 0.0686 0.0422 0.1014 0.0450 0.1509 0.0470 0.2006 0.0478
DocT5Query (logs/recall_docT5query.txt) 0.0779 0.0473 0.1138 0.0505 0.1767 0.0530 0.2343 0.0540
DPR (logs/recall_dpr.txt) 0.0584 0.0371 0.0919 0.0401 0.1491 0.0424 0.2011 0.0433

⚠️ recall_dpr.txt 训练结尾出现 NCCL 超时(AllReduce 600s timeout)导致进程异常退出,但同一份日志里已经保留了评估阶段的完整结果(STEP 3: Evaluating Performance 段)。DPR 训练在后期存在多卡通信不稳定的问题,可根据需要考虑降 world size 或增大 NCCL timeout。

Generative Retrieval (GR, DSI)

30 个 epoch、9h53m(logs/recall_gr.txt):

Hits@10 Recall@10 Hits@20 Recall@20 Hits@50 Recall@50
0.0947 0.0611 0.1305 0.0867 0.1947 0.1307
  • final train loss: 1.102
  • best eval metric: 2.155 @ checkpoint-213630

Relevance(测试集)

方法 Accuracy Macro F1 Weighted F1 ROC-AUC PR-AUC
Cross-Encoder (relevance_crossencoder.txt) 0.7331 0.6568 0.7115 0.7607 0.5997
Bi-Encoder / Embedding (relevance_embedding.txt) 0.6523 0.6161 0.6537 0.6652 0.4902
Generative Relevance (relevance_gr.txt) 0.7398 0.7124 0.7424 0.8033 0.6623

Bi-Encoder 在训练集上的最优阈值:0.7029(以 macro_f1 为目标优化,Train Accuracy 0.6578 / Macro F1 0.6159 / ROC-AUC 0.6677)。

Ranking

scripts/ranking_train.sh 走通(DCNv1、batch_size=20000num_epochs=20,单卡 CUDA_VISIBLE_DEVICES=0),训练+验证完整跑完。

数据加载情况logs/ranking_train.txt):

[INFO] item_emb:  shape=(6634118, 512)
[INFO] query_emb: shape=(161740, 512)
[INFO] Loaded users   = 102086
[INFO] Loaded items   = 6634118
[INFO] All-train candidates (split='train') = 7439261
[INFO] Test samples   (split='test') = 0
[INFO] Final Train samples = 6695335
[INFO] Final Valid samples = 743926 (10%)
[INFO] Final Test  samples = 0

⚠️ ranking 测试集为空ranking/datasets.pyrank.jsonl 中每行的 "split" 字段切分训练/测试,但 KuaiSearch 公开数据中所有行的 split 都是 "train",因此 Test samples = 0。训练过程中每个 epoch 都会打印 [WARN] No test_loader provided, skip test evaluation.,最终测试集指标(LogLoss / AUC)无法得到。我们在 ranking/main.py 里加了 None 守护,打印 [TEST RESULT] skipped (no test set) 而不是抛异常。

训练过程(epoch 1–14,early stopping at epoch 14,patience=2):

epoch train loss valid loss valid AUC
1 0.1623 0.1461 0.6106
2 0.1467 0.1452 0.6255
3 0.1456 0.1451 0.6369
4 0.1446 0.1439 0.6475
5 0.1438 0.1435 0.6479
6 0.1432 0.1434 0.6533
7 0.1426 0.1424 0.6613
8 0.1416 0.1422 0.6710
9 0.1399 0.1406 0.6896
10 0.1379 0.1390 0.7023
11 0.1363 0.1385 0.7079
12 0.1346 0.1384 0.7104 ← best
13 0.1289 0.1403 0.7044
14 0.1085 0.1526 0.6704
  • 最优 checkpoint:epoch 12valid_loss=0.138382valid_auc=0.710396,保存到 ./checkpoints_widedeep/best_model.pt
  • 单 epoch 训练时间约 7–8 分钟(335 iter,batch_size=20000)。
  • 由于缺少测试集,无法给出 test LogLoss / AUC;如果想获得测试指标,需要自行从训练集中划出一部分标记为 split="test"(或修改 ranking/datasets.py 强制按比例切分)。

踩坑记录

1. torch.load 安全漏洞限制模型加载(CVE-2025-32434)

现象:运行 bash scripts/recall_doc2query.sh 时报错:

ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`,
we now require users to upgrade torch to at least v2.6 ...

原因:新版 transformers 强制要求加载 .bin 格式权重时 PyTorch ≥ 2.6。

解决

# 方案 A:把 torch 升到 >=2.6
# 方案 B:加载 safetensors
model = MT5ForConditionalGeneration.from_pretrained(
    "google/mt5-base",
    use_safetensors=True,
)

2. HuggingFace 直连失败

现象SentenceTransformer("BAAI/bge-small-zh-v1.5") 直接超时或 403。

解决:已在脚本里改成 try/except fallback——直连失败会走 ensure_model_downloaded 从 ModelScope 拉取。也可手动:

modelscope download --model AI-ModelScope/xlm-roberta-base --local_dir ./model/XLM-roberta-base
modelscope download --model LLM-Research/Llama-3.2-3B-Instruct --local_dir ./model/meta-llama/Llama-3.2-3B

3. 多卡 NCCL 超时

现象:DPR 训练后期报 Watchdog caught collective operation timeout: WorkNCCL ... ran for 600033 milliseconds before timing out

缓解

  • 减少 world size(比如 6 → 4);
  • 或导出 NCCL_BLOCKING_WAIT=1NCCL_ASYNC_ERROR_HANDLING=1、提高 timeout

4. rank.jsonl 数据行损坏

现象:ranking 预处理时 json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 67

解决ranking/data/process.pycollect_session_queries / collect_items 已加 try/except,遇到坏行打印警告并跳过;同时建议 wc -l data/rank.jsonl 核对文件是否被截断。

5. Ranking 测试集缺失导致训练末尾崩溃

现象:DCNv1 训完 14 个 epoch、触发 early stopping 并加载 best checkpoint 之后,在 trainer.evaluate_test() 返回 (None, None) 时,main.py 里原先的 f"...{test_loss:.6f}..." 直接抛 TypeError: unsupported format string passed to NoneType.__format__,进程以非 0 退出。

根因rank.jsonl 里所有行的 split 字段都是 "train",没有 "test",导致 Test samples = 0test_loader 为空。

解决:已在 ranking/main.py 里加 None 守护,改为打印 [TEST RESULT] skipped (no test set),训练正常结束。如需测试集指标,可以改 ranking/datasets.py 里的划分逻辑,按比例从 raw_all_train 里再抽一份作为 test(当前实现里 test 严格按 split == "test" 判定)。


引用

如在研究中使用 KuaiSearch 数据集,请引用原论文:

@article{li2026kuaisearch,
  title   = {KuaiSearch: A Large-Scale E-Commerce Search Dataset for Recall, Ranking, and Relevance},
  author  = {Yupeng Li and Ben Chen and Mingyue Cheng and Zhiding Liu and Xuxin Zhang and Chenyi Lei and Wenwu Ou},
  journal = {arXiv preprint arXiv:2602.11518},
  year    = {2026},
  url     = {https://arxiv.org/abs/2602.11518}
}

About

复现 KuaiSearch 三大搜索模块(召回 / 相关性 / 排序)的工程化学习笔记 — Recall (BM25 / DocT5Query / DPR / GR), Relevance (CE / Bi-E / LLM), Ranking (DCN / DIN / W&D);含 ModelScope 兜底下载、多卡脚本加固、逐模块代码走读、本地指标与踩坑记录。

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors