@@ -1,4 +1,6 @@
+import re
 import typing as t
+import warnings
 from collections import defaultdict, namedtuple
 from dataclasses import dataclass
 
@@ -33,10 +35,10 @@
 )
 
 DEFAULT_TEST_DISTRIBUTION = {
-    "simple": 0.5,
+    "simple": 0.4,
     "reasoning": 0.2,
     "multi_context": 0.2,
-    "conditional": 0.1,
+    "conditional": 0.2,
 }
 
 question_deep_map = {
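The new distribution still sums to 1.0 (0.4 + 0.2 + 0.2 + 0.2); weight simply moves from "simple" to "conditional" questions. A minimal sketch of how such a weighted draw can be made, assuming numpy's Generator (the helper name pick_question_type is hypothetical, not part of this module):

import numpy as np

DEFAULT_TEST_DISTRIBUTION = {
    "simple": 0.4,
    "reasoning": 0.2,
    "multi_context": 0.2,
    "conditional": 0.2,
}

def pick_question_type(rng: np.random.Generator, distribution: dict) -> str:
    # Hypothetical helper: draw one question type according to its weight.
    assert abs(sum(distribution.values()) - 1.0) < 1e-9, "weights must sum to 1"
    return str(rng.choice(list(distribution), p=list(distribution.values())))

rng = np.random.default_rng(42)
print(pick_question_type(rng, DEFAULT_TEST_DISTRIBUTION))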
@@ -106,7 +108,7 @@ def __init__(
         critic_llm: BaseLLM | BaseChatModel,
         embeddings_model: Embeddings,
         testset_distribution: t.Optional[t.Dict[str, float]] = None,
-        chat_qa: float = 0.3,
+        chat_qa: float = 0.0,
         chunk_size: int = 1024,
         seed: int = 42,
     ) -> None:
@@ -135,7 +137,7 @@ def from_default(
         openai_generator_llm: str = "gpt-3.5-turbo-16k",
         openai_filter_llm: str = "gpt-4",
         chat_qa: float = 0.3,
-        chunk_size: int = 1024,
+        chunk_size: int = 512,
     ):
         generator_llm = ChatOpenAI(model=openai_generator_llm)
         critic_llm = ChatOpenAI(model=openai_filter_llm)
@@ -173,14 +175,12 @@ def _filter_context(self, context: str) -> bool:
         prompt = ChatPromptTemplate.from_messages([human_prompt])
         results = generate(prompts=[prompt], llm=self.critic_llm)
         output = results.generations[0][0].text.strip()
-        score = eval(output)
-        if not isinstance(score, float | int):
-            index = output.lower().find("score:")
-            if index != -1:
-                index += len("score:")
-                score = eval(output[index:])
-            else:
-                score = 0.0
+        pattern = r"^[\d.]+$"
+        if not re.match(pattern, output):
+            score = 0.0
+        else:
+            score = eval(output)
+
         return score >= self.threshold
 
     def _seed_question(self, context: str) -> str:
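The reworked score parsing above only calls eval once a regex has confirmed the critic output is purely numeric, instead of eval-ing the raw model text and then hunting for a "score:" prefix. A minimal standalone sketch of the same idea, using float() in place of eval (parse_score is an illustrative helper, not a method of the generator):

import re

def parse_score(output: str, default: float = 0.0) -> float:
    # Accept only strings made of digits and dots (e.g. "7.5"); anything else,
    # such as prose like "Score: 7", falls back to the default.
    text = output.strip()
    if not re.match(r"^[\d.]+$", text):
        return default
    try:
        return float(text)
    except ValueError:  # e.g. "1.2.3" passes the regex but is not a number
        return default

assert parse_score("7.5") == 7.5
assert parse_score("Score: 7.5") == 0.0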
@@ -241,22 +241,30 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
             for qstn in question.split("\n")
         ]
 
-    def _remove_index(self, available_indices: list, node_idx: list) -> t.List:
+    def _remove_nodes(self, available_indices: list, node_idx: list) -> t.List:
         for idx in node_idx:
             available_indices.remove(idx)
         return available_indices
 
-    def _generate_doc_node_map(
+    def _generate_doc_nodes_map(
         self, documenet_nodes: t.List[BaseNode]
-    ) -> t.Dict[str, list]:
-        doc_nodeidx = defaultdict(list)
-        for idx, node in enumerate(documenet_nodes):
-            doc_nodeidx[node.id_].append(idx)
-
-        return doc_nodeidx
-
-    def _get_neighbour_node(self, idx: int, node_indices: list) -> t.List[int]:
-        return [idx - 1, idx] if idx == node_indices[-1] else [idx, idx + 1]
+    ) -> t.Dict[str, BaseNode]:
+        doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
+        for node in documenet_nodes:
+            if node.ref_doc_id:
+                doc_nodes_map[node.ref_doc_id].append(node)
+
+        return doc_nodes_map  # type: ignore
+
+    def _get_neighbour_node(
+        self, node: BaseNode, related_nodes: list[BaseNode]
+    ) -> t.List[BaseNode]:
+        if len(related_nodes) < 2:
+            warnings.warn("No neighbors exists")
+            return [node]
+        idx = related_nodes.index(node)
+        ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
+        return [related_nodes[idx] for idx in ids]
 
     def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
         embeddings = {}
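The rewritten _get_neighbour_node pairs a chunk with an adjacent chunk from the same source document, and falls back to the chunk alone when the document produced only a single node. A minimal sketch of the same adjacency rule on plain strings instead of BaseNode objects (get_neighbour_items is an illustrative stand-in):

import typing as t
import warnings

def get_neighbour_items(item: str, related: t.List[str]) -> t.List[str]:
    # Return the item together with one adjacent item from the same group;
    # the last item pairs with its left neighbour, every other item with its right.
    if len(related) < 2:
        warnings.warn("no neighbours exist for this item")
        return [item]
    idx = related.index(item)
    pair = [idx - 1, idx] if idx == len(related) - 1 else [idx, idx + 1]
    return [related[i] for i in pair]

assert get_neighbour_items("c", ["a", "b", "c"]) == ["b", "c"]
assert get_neighbour_items("a", ["a", "b", "c"]) == ["a", "b"]
assert get_neighbour_items("a", ["a"]) == ["a"]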
@@ -275,38 +283,38 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
         document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
             documents=documents
         )
-
         # maximum 1 seed question per node
         if test_size > len(document_nodes):
             raise ValueError(
                 """Maximum possible number of samples exceeded,
                 reduce test_size or add more documents"""
             )
 
-        available_indices = np.arange(0, len(document_nodes)).tolist()
-        doc_nodeidx = self._generate_doc_node_map(document_nodes)
+        available_nodes = document_nodes
+        doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
+        count_neighbours = sum(len(val) > 1 for _, val in doc_nodes_map.items())
+        if count_neighbours < len(documents) // 2:
+            warnings.warn("Most documents are too short")
+
         count = 0
         samples = []
 
         pbar = tqdm(total=test_size)
-        while count < test_size and available_indices != []:
+        while count < test_size and available_nodes != []:
             evolve_type = self._get_evolve_type()
-            node_idx = self.rng.choice(available_indices, size=1)[0]
-            available_indices = self._remove_index(available_indices, [node_idx])
+            curr_node = self.rng.choice(available_nodes, size=1)[0]
+            available_nodes = self._remove_nodes(available_nodes, [curr_node])
 
-            neighbor_nodes = doc_nodeidx[
-                document_nodes[node_idx].node_id  # type: ignore
-            ]
+            neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]
 
             # Append multiple nodes randomly to remove chunking bias
             size = self.rng.integers(1, 3)
-            node_indices = (
-                self._get_neighbour_node(node_idx, neighbor_nodes)
+            nodes = (
+                self._get_neighbour_node(curr_node, neighbor_nodes)
                 if size > 1 and evolve_type != "multi_context"
-                else [node_idx]
+                else [curr_node]
             )
 
-            nodes = [document_nodes[node_idx] for node_idx in node_indices]
             text_chunk = " ".join([node.get_content() for node in nodes])
             score = self._filter_context(text_chunk)
             if not score:
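The sampling loop above now draws whole nodes and removes each drawn node from the pool, i.e. sampling without replacement, so a chunk can seed at most one question. A minimal sketch of that loop with strings standing in for nodes (all names here are illustrative):

import numpy as np

rng = np.random.default_rng(42)
available = ["node-a", "node-b", "node-c", "node-d"]
picked = []
while len(picked) < 2 and available:
    # choice() over the list returns one element; remove it so it
    # cannot be drawn again (sampling without replacement).
    curr = rng.choice(available)
    available.remove(curr)
    picked.append(curr)
print(picked)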
@@ -316,14 +324,13 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
             if evolve_type == "multi_context":
                 # Find most similar chunk in same document
                 node_embedding = self._embed_nodes([nodes[-1]])
-                neighbor_nodes = self._remove_index(neighbor_nodes, node_indices)
-                neighbor_emb = self._embed_nodes(
-                    [document_nodes[idx][0] for idx in neighbor_nodes]
-                )
+                neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
+                neighbor_emb = self._embed_nodes(neighbor_nodes)
+
                 _, indices = get_top_k_embeddings(
                     list(node_embedding.values())[0],
                     list(neighbor_emb.values()),
-                    similarity_cutoff=self.threshold,
+                    similarity_cutoff=self.threshold / 10,
                 )
                 if indices:
                     best_neighbor = neighbor_nodes[indices[0]]
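For multi_context questions the partner chunk is the remaining node from the same document whose embedding is most similar to the seed chunk, subject to a similarity cutoff. A minimal sketch of that selection using plain cosine similarity (most_similar is an illustrative stand-in for llama_index's get_top_k_embeddings; the cutoff value is arbitrary):

import numpy as np

def most_similar(query_emb: list, candidate_embs: list, cutoff: float = 0.75) -> int | None:
    # Return the index of the candidate most similar to the query, or None
    # if nothing clears the similarity cutoff.
    q = np.asarray(query_emb)
    best_idx, best_sim = None, cutoff
    for i, emb in enumerate(candidate_embs):
        c = np.asarray(emb)
        sim = float(q @ c / (np.linalg.norm(q) * np.linalg.norm(c)))
        if sim >= best_sim:
            best_idx, best_sim = i, sim
    return best_idx

print(most_similar([1.0, 0.0], [[0.9, 0.1], [0.0, 1.0]]))  # -> 0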
@@ -332,7 +339,7 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
                         context1=text_chunk,
                         context2=best_neighbor.get_content(),
                     )
-                    text_chunk = "\n".join([text_chunk, best_neighbor.get_context()])
+                    text_chunk = "\n".join([text_chunk, best_neighbor.get_content()])
                 else:
                     continue
 