-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtiny_rag.py
More file actions
325 lines (276 loc) · 10.7 KB
/
tiny_rag.py
File metadata and controls
325 lines (276 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
基于硅基流动线上资源的最小 RAG 示例:
1. 读取本地 txt 文件,切片并调用嵌入接口存入 sqlite。
2. 接收用户问题,生成嵌入,在本地向量库里检索相似片段。
3. 将片段连同问题一并提交给硅基流动大模型,得到回答。
准备工作:
1. 安装依赖:pip install -r requirements.txt
2. 修改 API_KEY 常量。
"""
from __future__ import annotations
import json
import math
import os
import sqlite3
from pathlib import Path
from typing import Iterable, List, Sequence, Tuple
import requests
API_BASE = "https://api.siliconflow.cn/v1"
CHAT_MODEL = "Qwen/Qwen3-8B"  # Chat model used to generate answers
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"  # Embedding model used to index knowledge-base documents
# The API key may come from the SILICONFLOW_API_KEY environment variable (keeps the
# secret out of source control) or from editing the fallback below. The fallback is
# the placeholder that _headers() rejects, so an unconfigured key fails fast.
# `or` (rather than a get() default) also treats an empty env var as "not set".
API_KEY = os.environ.get("SILICONFLOW_API_KEY") or "your_api_key"
DB_PATH = Path("knowledge_base.db")  # SQLite file holding the knowledge base
DOC_DIR = Path("doc")  # Default directory of documents to index
DEFAULT_TOP_K = 3  # Default number of chunks to retrieve
DEFAULT_THRESHOLD = 0.3  # Default minimum cosine similarity for a retrieved chunk
# ----------------------------- HTTP helpers ----------------------------- #
def _headers() -> dict:
    """Build the HTTP headers sent with every SiliconFlow API request.

    Returns:
        A dict with a Bearer ``Authorization`` header and a JSON ``Content-Type``.

    Raises:
        RuntimeError: if ``API_KEY`` is empty or still the ``"your_api_key"`` placeholder.
    """
    placeholder = "your_api_key"
    if API_KEY and API_KEY != placeholder:
        return {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }
    raise RuntimeError("请先配置有效的 API_KEY(可使用环境变量或修改常量)")
def fetch_embedding(text: str) -> List[float]:
    """Return the embedding vector for *text* from the SiliconFlow embeddings endpoint.

    A single string is sent to ``EMBEDDING_MODEL``; the embedding of the first
    item in the response's ``data`` list is returned.

    Raises:
        RuntimeError: if the API key is not configured.
        requests.HTTPError: if the endpoint answers with a non-2xx status.
    """
    url = f"{API_BASE}/embeddings"
    response = requests.post(
        url,
        headers=_headers(),
        json={"model": EMBEDDING_MODEL, "input": text},
        timeout=60,
    )
    response.raise_for_status()
    first_item = response.json()["data"][0]
    return first_item["embedding"]
def chat_completion(messages: List[dict], temperature: float = 0.7) -> str:
    """Send *messages* to the chat model in one non-streaming request and return the reply text.

    Args:
        messages: chat messages, each a ``{"role": ..., "content": ...}`` dict.
        temperature: sampling temperature forwarded to the API.

    Raises:
        RuntimeError: if the API key is not configured.
        requests.HTTPError: if the endpoint answers with a non-2xx status.
    """
    body = dict(
        model=CHAT_MODEL,
        messages=messages,
        temperature=temperature,
        stream=False,
    )
    response = requests.post(
        f"{API_BASE}/chat/completions",
        headers=_headers(),
        json=body,
        timeout=120,
    )
    response.raise_for_status()
    first_choice = response.json()["choices"][0]
    return first_choice["message"]["content"]
def chat_completion_stream(messages: List[dict], temperature: float = 0.7) -> Iterable[str]:
    """Stream the chat model's reply, yielding text fragments as they arrive.

    The endpoint answers with server-sent events. Each ``data:`` line carries one
    JSON chunk, and a ``[DONE]`` payload ends the stream.

    Args:
        messages: chat messages, each a ``{"role": ..., "content": ...}`` dict.
        temperature: sampling temperature forwarded to the API.

    Yields:
        Non-empty ``delta.content`` strings, in arrival order.

    Raises:
        RuntimeError: if the API key is not configured.
        requests.HTTPError: if the endpoint answers with a non-2xx status.
    """
    payload = {
        "model": CHAT_MODEL,
        "messages": messages,
        "temperature": temperature,
        "stream": True,
    }
    with requests.post(
        f"{API_BASE}/chat/completions",
        headers=_headers(),
        json=payload,
        stream=True,
        timeout=300,
    ) as resp:
        resp.raise_for_status()
        for raw_line in resp.iter_lines(decode_unicode=False):
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if not line.startswith("data:"):
                # SSE comments (": ...") and other fields (event:, id:) carry no content.
                continue
            # The space after "data:" is optional in SSE. Slice off exactly the
            # field name and strip, so "data: {...}", "data:{...}" and
            # "data:[DONE]" are all handled the same way.
            data = line[len("data:"):].strip()
            if data == "[DONE]":
                break
            if not data:
                continue
            chunk = json.loads(data)
            # A chunk may carry an empty "choices" list or a null delta/content
            # (e.g. a trailing metadata-only chunk); skip it instead of raising.
            choices = chunk.get("choices") or []
            if not choices:
                continue
            delta = (choices[0].get("delta") or {}).get("content")
            if delta:
                yield delta
# ----------------------------- SQLite helpers ----------------------------- #
def ensure_schema(conn: sqlite3.Connection) -> None:
    """Create the ``documents`` table on *conn* if it does not exist, then commit.

    Safe to call repeatedly because of ``IF NOT EXISTS``. Columns:

    * ``id`` - autoincrementing integer primary key
    * ``path`` - source file path of the chunk
    * ``chunk_index`` - position of the chunk within its file
    * ``content`` - the chunk text
    * ``embedding`` - the chunk's embedding, stored as TEXT
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            path TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            content TEXT NOT NULL,
            embedding TEXT NOT NULL
        )
    """
    conn.execute(create_table_sql)
    conn.commit()
def upsert_chunk(
    conn: sqlite3.Connection,
    *,
    path: str,
    chunk_index: int,
    content: str,
    embedding: Sequence[float],
) -> None:
    """Store one chunk, replacing any existing row(s) with the same ``(path, chunk_index)``.

    The ``documents`` table has no uniqueness constraint on ``(path, chunk_index)``.
    A plain INSERT would therefore duplicate every chunk each time a file is
    re-indexed, and retrieval would return the same text several times.
    Deleting first gives real upsert semantics and also cleans up duplicates
    left by earlier runs.

    The embedding is serialized as a JSON array string. Nothing is committed
    here; the caller controls the transaction.

    Args:
        conn: open SQLite connection whose database has the ``documents`` table.
        path: source file path the chunk came from.
        chunk_index: position of the chunk within that file.
        content: chunk text.
        embedding: embedding vector of the chunk.
    """
    conn.execute(
        "DELETE FROM documents WHERE path = ? AND chunk_index = ?",
        (path, chunk_index),
    )
    conn.execute(
        """
        INSERT INTO documents (path, chunk_index, content, embedding)
        VALUES (?, ?, ?, ?)
        """,
        (path, chunk_index, content, json.dumps(list(embedding))),
    )
def load_all_chunks(conn: sqlite3.Connection) -> List[Tuple[int, str, int, str, List[float]]]:
    """Return every stored chunk as ``(id, path, chunk_index, content, embedding)``.

    The ``embedding`` column is decoded from its JSON text with ``json.loads``.
    Rows come back in whatever order SQLite scans the table (no ORDER BY).
    """
    cursor = conn.execute("SELECT id, path, chunk_index, content, embedding FROM documents")
    return [
        (row_id, file_path, position, text, json.loads(embedding_json))
        for row_id, file_path, position, text, embedding_json in cursor.fetchall()
    ]
# ----------------------------- Text utilities ----------------------------- #
def iter_txt_files(root: Path) -> Iterable[Path]:
    """Lazily yield every ``*.txt`` path under *root* (recursively) that is a file.

    Paths are yielded in sorted order. Entries matching ``*.txt`` that are not
    files according to ``Path.is_file`` (e.g. directories) are skipped.
    """
    candidates = sorted(root.rglob("*.txt"))
    yield from filter(Path.is_file, candidates)
def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    """Split *text* into fixed-size, overlapping character windows.

    A new window starts every ``chunk_size - overlap`` characters. Each window is
    at most ``chunk_size`` characters long and is stripped of surrounding
    whitespace. Windows that are empty after stripping are dropped.

    Args:
        text: the text to split.
        chunk_size: maximum window length in characters; must be > 0.
        overlap: number of characters shared by consecutive windows; must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        The stripped, non-empty chunks in order (empty for empty or blank text).

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size 必须大于 0")
    # Reject bad overlaps up front. A step of 0 (overlap == chunk_size) would loop
    # forever. A negative step would stop after the first window. A negative
    # overlap would silently skip text between windows.
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap 必须满足 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    chunks: List[str] = []
    length = len(text)
    for start in range(0, length, step):
        end = min(start + chunk_size, length)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # Once a window reaches the end of the text, every later window would be
        # a suffix of this one, so stop rather than emit a redundant chunk.
        if end == length:
            break
    return chunks
def cosine_similarity(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:
    """Return the cosine similarity of two equal-length vectors.

    If either vector has zero Euclidean norm the result is 0.0, which avoids a
    division by zero.

    Raises:
        ValueError: if the vectors differ in length.
    """
    if len(vec_a) != len(vec_b):
        raise ValueError("向量维度不一致")

    def magnitude(vec: Sequence[float]) -> float:
        return math.sqrt(sum(x * x for x in vec))

    dot_product = sum(x * y for x, y in zip(vec_a, vec_b))
    mag_a = magnitude(vec_a)
    mag_b = magnitude(vec_b)
    if mag_a and mag_b:
        return dot_product / (mag_a * mag_b)
    return 0.0
# ----------------------------- RAG pipeline ----------------------------- #
def build_knowledge_base(source_dir: Path, chunk_size: int = 400, overlap: int = 50) -> None:
    """Index every ``*.txt`` file under *source_dir* into the SQLite knowledge base.

    Each file is split with ``chunk_text``, each chunk is embedded with
    ``fetch_embedding``, and the result is stored in ``DB_PATH``. A commit after
    every chunk preserves progress if a later request fails.

    Before a file is re-indexed, all of its previously stored chunks are removed.
    Rebuilding therefore neither duplicates rows nor leaves stale chunks behind
    when a file got shorter.

    Args:
        source_dir: directory searched recursively for ``.txt`` files.
        chunk_size: chunk length in characters, passed to ``chunk_text``.
        overlap: overlap between consecutive chunks, passed to ``chunk_text``.

    Raises:
        RuntimeError: if no ``.txt`` files are found (or the API key is missing).
    """
    txt_files = list(iter_txt_files(source_dir))
    if not txt_files:
        raise RuntimeError(f"目录 {source_dir} 下未找到 txt 文件")
    conn = sqlite3.connect(DB_PATH)
    try:
        ensure_schema(conn)
        for file_path in txt_files:
            # utf-8-sig decodes plain UTF-8 unchanged and strips a leading BOM,
            # so U+FEFF does not leak into the first chunk.
            text = file_path.read_text(encoding="utf-8-sig")
            chunks = chunk_text(text, chunk_size=chunk_size, overlap=overlap)
            path_key = str(file_path)
            # Drop the file's old chunks. The deletion is committed together with
            # the first new chunk (or by the final commit). If anything fails
            # before then, closing the connection discards it.
            conn.execute("DELETE FROM documents WHERE path = ?", (path_key,))
            for idx, chunk in enumerate(chunks):
                embedding = fetch_embedding(chunk)
                upsert_chunk(conn, path=path_key, chunk_index=idx, content=chunk, embedding=embedding)
                conn.commit()
                print(f"[索引] {file_path.name} chunk#{idx} 已写入")
        # Commits deletions for files that produced no chunks (e.g. blank files).
        conn.commit()
    finally:
        # Close even when reading a file or an embedding request raises.
        conn.close()
    print("知识库构建完成。")
def retrieve_contexts(
    question: str,
    top_k: int = 3,
    score_threshold: float = 0.3,
) -> List[Tuple[float, str, str, int]]:
    """Return the stored chunks most similar to *question*.

    The question is embedded and every stored chunk is scored by cosine
    similarity. Chunks scoring at least *score_threshold* are kept, and at most
    *top_k* of them are returned.

    Args:
        question: the user question to embed.
        top_k: maximum number of chunks to return; ``<= 0`` yields no results.
        score_threshold: minimum cosine similarity for a chunk to be kept.

    Returns:
        ``(score, content, path, chunk_index)`` tuples in descending score order.
        The list is empty when nothing passes the threshold.

    Raises:
        RuntimeError: if the knowledge base is empty (or the API key is missing).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        ensure_schema(conn)
        chunks = load_all_chunks(conn)
    finally:
        # Close even if loading or decoding a stored row fails.
        conn.close()
    if not chunks:
        raise RuntimeError("知识库为空,请先执行 build 命令")
    # A negative slice bound would return "all but the last k" hits, so any
    # non-positive top_k means "no results". This also skips the embedding call.
    if top_k <= 0:
        return []
    query_vec = fetch_embedding(question)
    scored = [
        (cosine_similarity(query_vec, embedding), content, path, chunk_idx)
        for _, path, chunk_idx, content, embedding in chunks
    ]
    scored.sort(key=lambda item: item[0], reverse=True)
    return [item for item in scored if item[0] >= score_threshold][:top_k]
def _prepare_prompt(
    question: str,
    top_k: int = 3,
    score_threshold: float = 0.3,
) -> Tuple[List[dict], str]:
    """Retrieve context for *question* and build the chat messages for the model.

    Each retrieved chunk is printed to stdout with its rank, score, source and text.

    Args:
        question: the user question.
        top_k: maximum number of chunks to retrieve.
        score_threshold: minimum similarity for a chunk to be used.

    Returns:
        ``(messages, "")`` when chunks were found. Otherwise ``([], message)``,
        where ``message`` is a user-facing explanation of the miss.
    """
    hits = retrieve_contexts(question, top_k=top_k, score_threshold=score_threshold)
    if not hits:
        return [], "检索未命中相关片段,请检查知识库或降低阈值。"

    print("\n=== 检索命中 ===")
    for rank, (score, content, path, chunk_idx) in enumerate(hits, start=1):
        print(f"[Top{rank}] 相似度:{score:.4f}")
        print(f"来源:{path}(段落 #{chunk_idx})")
        print(content)
        print("-" * 40)

    # Tag every snippet with its source so the model can cite it.
    cited_snippets = [
        f"[source={path}#chunk{chunk_idx}] {content}"
        for _score, content, path, chunk_idx in hits
    ]
    context_text = "\n\n".join(cited_snippets)

    system_prompt = (
        "你是一位知识库问答助手。回答必须严格基于提供的片段,"
        "并在涉及事实的句子末尾使用 [source=文件路径#chunk编号] 进行引用。"
        "若片段不足以回答,请明确说明。"
    )
    user_prompt = (
        f"已知文档片段:\n{context_text}\n\n"
        f"问题:{question}\n"
        "请基于片段回答,确保在相关句子后附上对应的 source 引用。"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return messages, ""
def answer_question(
    question: str,
    top_k: int = 3,
    score_threshold: float = 0.3,
    temperature: float = 0.7,
) -> str:
    """Answer *question* with retrieval-augmented generation in one non-streaming call.

    The outgoing messages are printed as JSON before the model is called.

    Args:
        question: the user question.
        top_k: maximum number of chunks to retrieve.
        score_threshold: minimum similarity for a chunk to be used.
        temperature: sampling temperature. Matches ``answer_question_stream``;
            the default equals ``chat_completion``'s default.

    Returns:
        The model's answer. When retrieval found nothing, returns the
        explanatory message from ``_prepare_prompt`` instead.
    """
    messages, error = _prepare_prompt(question, top_k=top_k, score_threshold=score_threshold)
    if not messages:
        return error
    print("\n=== 调用大模型 ===")
    print(json.dumps(messages, ensure_ascii=False, indent=2))
    return chat_completion(messages, temperature=temperature)
def answer_question_stream(
    question: str,
    top_k: int = 3,
    score_threshold: float = 0.3,
    temperature: float = 0.7,
) -> Iterable[str]:
    """Answer *question* with retrieval-augmented generation, yielding the reply incrementally.

    If retrieval finds nothing, yields only the explanatory message from
    ``_prepare_prompt`` and never calls the model. Otherwise prints the outgoing
    messages as JSON, then yields the streamed text fragments.
    """
    prompt_messages, fallback = _prepare_prompt(
        question, top_k=top_k, score_threshold=score_threshold
    )
    if prompt_messages:
        print("\n=== 调用大模型 ===")
        print(json.dumps(prompt_messages, ensure_ascii=False, indent=2))
        yield from chat_completion_stream(prompt_messages, temperature=temperature)
    else:
        yield fallback
# ----------------------------- CLI entry ----------------------------- #
def _cli_build() -> None:
    """Menu option 1: prompt for indexing options and build the knowledge base."""
    source = input(f"文档目录(默认 {DOC_DIR}):").strip()
    source_dir = Path(source) if source else DOC_DIR
    chunk_size = input("切片长度(默认 400):").strip()
    overlap = input("切片重叠长度(默认 50):").strip()
    build_knowledge_base(
        source_dir,
        chunk_size=int(chunk_size) if chunk_size else 400,
        overlap=int(overlap) if overlap else 50,
    )


def _cli_chat() -> None:
    """Menu option 2: prompt for a question and retrieval options, then stream the answer."""
    question = input("请输入你的问题:").strip()
    if not question:
        print("问题不能为空。")
        return
    top_k = input(f"召回片段数量(默认 {DEFAULT_TOP_K}):").strip()
    threshold = input(f"相似度阈值(默认 {DEFAULT_THRESHOLD}):").strip()
    # Parse before printing the answer-start banner, so a typo is reported
    # cleanly instead of appearing inside an empty answer section.
    top_k_value = int(top_k) if top_k else DEFAULT_TOP_K
    threshold_value = float(threshold) if threshold else DEFAULT_THRESHOLD
    print("\n===== 回答开始 =====")
    for chunk in answer_question_stream(
        question,
        top_k=top_k_value,
        score_threshold=threshold_value,
    ):
        print(chunk, end="", flush=True)
    print("\n===== 回答结束 =====")


def interactive_cli() -> None:
    """Run the interactive menu: 1 builds the knowledge base, 2 asks a question, 0 quits.

    A failing action is reported and the menu continues, instead of the whole
    program terminating. This covers invalid numeric input, a missing API key,
    an empty knowledge base, and file, network or SQLite errors. Ctrl-C during
    an action cancels only that action. EOF, or Ctrl-C at the menu prompt,
    exits cleanly.
    """
    print("=== 硅基流动 RAG 命令行 ===")
    print("1. 知识库初始化(索引 doc 目录)")
    print("2. 开始对话(RAG 问答)")
    print("0. 退出")
    actions = {"1": _cli_build, "2": _cli_chat}
    try:
        while True:
            choice = input("\n请选择操作 [1/2/0]:").strip()
            if choice == "0":
                print("Bye~")
                return
            action = actions.get(choice)
            if action is None:
                print("无效选项,请重新选择。")
                continue
            try:
                action()
            except KeyboardInterrupt:
                print("\n[已取消]")
            except (
                ValueError,
                RuntimeError,
                OSError,
                sqlite3.Error,
                requests.RequestException,
            ) as exc:
                print(f"\n[错误] {exc}")
    except (EOFError, KeyboardInterrupt):
        print("\nBye~")
if __name__ == "__main__":
interactive_cli()