Turn Detection(对话轮次检测)是一个用于人机对话系统中的关键技术,主要用于:
- 对话边界识别:准确判断用户何时结束当前发言,避免对话系统过早或过晚响应
- 多轮对话管理:在连续对话中识别每个对话轮次的开始和结束,提升对话体验
- 实时交互优化:通过精准的轮次检测,实现更自然流畅的人机交互
- 语音助手增强:为语音助手、客服机器人等应用提供更智能的对话控制
模型基于gemma3 270M模型进行微调,提供了完整的数据集和微调脚本。 效果媲美7B模型效果。
-
能够处理复杂的多轮对话场景
-
准确识别对话中的停顿、思考和真正的轮次结束
-
支持上下文感知的轮次判断
支持多轮对话的重要性:
user: 我们来个成语接龙吧? assistant: 那我先来,杞人忧天。该你了 user: 天天向上
非多轮对话下单一的"天天向上"是不完整的,但是放在上下文中则应该是完整的。
- 模型参数仅270M,资源占用低
- 支持CPU推理,无需GPU即可部署
- 推理速度快,满足实时对话需求
- 适合边缘设备和资源受限环境
- 原生支持中文和英文对话检测
- 模型架构支持微调扩展到其他语言
- 跨语言泛化能力强
- 提供完整的微调框架
- 支持针对特定领域和语言的定制训练
- 灵活的数据处理和训练流程
- 0 (不完整):用户话语未说完,需要等待继续输入
- 1 (完整):用户话语表达完整,可以进行回复
- 2 (要求等待):用户要求暂停或打断AI回复
中文单轮和多轮数据:使用LLM合成 英文单轮和多轮数据:turns-2k数据集使用LLM扩展为多轮
使用 finetune.py
进行模型微调:
pip install -r requirements.txt
python finetune.py
如果微调的过程中出现下面的错误,unsloth依赖的triton版本过高,需要卸载triton版本,重新安装triton-3.2.0版本
pip uninsatll triton
pip install triton==3.2.0
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
中文准确率: 0.9591 (258/269)
英文准确率: 0.9654 (223/231)
总体准确率: 0.9620 (481/500)
Nvidia T4单卡推理耗时: <100ms (P95<80ms)
- 中文场景使用中文的system prompt,英文场景使用英文的system prompt
- wait场景在多轮对话中才有效,结合实际使用场景,训练集中wait场景均为多轮对话。
- 训练数据中未使用通用数据集进行配比训练,所以通用能力可能会有下降。如果需要通用能力请在当前数据集基础上添加通用数据集进行训练,通常做1:1配比即可。
国内访问huggingface遇到网络问题时,可以设置
# For Linux or MacOS
export HF_ENDPOINT=https://hf-mirror.com
或
# For Windows PowerShell
$env:HF_ENDPOINT = "https://hf-mirror.com"
# 启动HTTP API服务
vllm serve gemma3-270m-full-turn-detection --served-model-name=gemma3 --port 8000 --enable-prefix-caching --gpu-memory-utilization 0.8
# 调用API
curl -X POST http://localhost:8000/v1 \
-H "Content-Type: application/json" \
-d '{"audio_data": "base64_encoded_audio"}'
也兼容openAI库。
from inference import TurnDetector
# 初始化检测器
detector = TurnDetector(
model_path="gemma3-270m-full-turn-detection", # 模型路径
device="auto" # 自动选择设备,也可以指定"cpu"或"cuda"
)
# 方式1: 字符串格式输入
conversation_str = """user: 我们来成语接龙吧
assistant: 杞人忧天
user: 天天向上"""
result = detector.detect(conversation_str)
print(f"检测结果: {result}") # 0-完整, 1-不完整, 2-要求等待
# 方式2: 消息列表格式输入
conversation_msgs = [
{"role": "user", "content": "我们来成语接龙吧"},
{"role": "assistant", "content": "杞人忧天"},
{"role": "user", "content": "天天向上"}
]
result = detector.detect(conversation_msgs)
print(f"检测结果: {result}")
# 方式3: 获取详细结果
detailed_result = detector.detect_with_explanation(conversation_str)
print(f"状态码: {detailed_result['status_code']}")
print(f"状态名: {detailed_result['status_name']}")
print(f"说明: {detailed_result['description']}")
# 方式4: 批量检测
conversations = [
"user: 我想要...",
"user: 停",
"user: 今天天气很好"
]
results = detector.detect_batch(conversations)
print(f"批量检测结果: {results}") # [1, 2, 0]
# 交互式模式
python inference.py --interactive
# 单次检测
python inference.py --input "user: 我想要..."
# 批量检测
python inference.py --input_file conversations.json --output_file results.json
# 指定设备和参数
python inference.py --device cpu --temperature 0.1 --interactive
# 演示示例
python inference.py
# 启动HTTP API服务
vllm serve gemma3-270m-full-turn-detection --gpu-memory-utilization 0.8 --enable-prefix-caching --served-model-name=gemma3-turn-detection --port 8080
# 调用API
curl -X POST "http://localhost:8080/v1/chat/completions" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-xx" \
-d '{
"model": "gemma3-turn-detection",
"temperature": 1.0,
"top_p": 0.95,
"top_k": 64,
"messages": [
{
"role": "system",
"content": "你是一个专门分析对话状态的AI助手。请根据对话历史,判断用户最后说的话属于以下哪种状态:\n\n**状态定义:**\n- 0 (不完整):用户的话语表达完整,意思清晰明确,不需要继续补充\n- 1 (完整):用户的话语未说完,存在停顿、犹豫或明显的未完成表达\n- 2 (要求等待):用户明确表示要打断或暂停AI的回复,要求AI停止说话或等待\n\n**判断标准:**\n\n**不完整(0)的特征:**\n- 句子突然中断,没有完整表达意思\n- 包含停顿词:如"嗯"、"那个"、"就是"、"呃"等\n- 语句结构不完整,明显还有后续内容\n- 例如:"我想要..."、"关于这个问题,嗯..."、"山字怎么"\n\n**完整(1)的特征:**\n- 句子结构完整,语法正确\n- 表达了清晰的意图或完整的信息\n- 没有明显的停顿词或未完成标记\n- 例如:"我想去北京旅游"、"今天天气很好"、"谢谢你的帮助"\n\n**要求等待(2)的特征:**\n- 明确的打断指令:如"停"、"等等"、"暂停"、"闭嘴"\n- 礼貌的暂停请求:如"稍等"、"等一下"、"先别说"\n- 表达需要时间思考:如"让我想想"、"我需要安静"\n- 表达不耐烦:如"够了"、"太多了"、"别说了"\n- 英文打断:如"Stop"、"Wait"、"Hold on"、"Shut up"、"Enough"\n\n\n**输出格式:**\n你只能输出[0、1、2]中的其中一个数字,不要输出其他的。"
},
{
"role": "user",
"content": "请分析以下对话中用户最后说的话:\nuser: 我们来成语接龙吧\nassistant: 杞人忧天\nuser: 停"
}
]
}'
- 可以基于提供的训练脚本新增其他语种的语料进行继续微调。每个语种在200条数据即可达到比较好的效果
- 模型可以量化以进一步降低资源占用,提升推理效率。
- Unsloth: 优秀的微调框架
- Gemma3: 优秀的开源模型权重
- ten-turn-detection: 参考了其wait数据集,并对比了其模型效果
This project is Apache 2.0 licensed with certain conditions.