Skip to content

Commit 0984435

Browse files
authored
testset generation: bug fixes (#185)
Fixes:
- [x] issues with multi-context question generation
- [x] error in doc filtering
1 parent 6787a5c commit 0984435

File tree

1 file changed

+49
-42
lines changed

1 file changed

+49
-42
lines changed

src/ragas/testset/testset_generator.py

+49-42
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import re
12
import typing as t
3+
import warnings
24
from collections import defaultdict, namedtuple
35
from dataclasses import dataclass
46

@@ -33,10 +35,10 @@
3335
)
3436

3537
DEFAULT_TEST_DISTRIBUTION = {
36-
"simple": 0.5,
38+
"simple": 0.4,
3739
"reasoning": 0.2,
3840
"multi_context": 0.2,
39-
"conditional": 0.1,
41+
"conditional": 0.2,
4042
}
4143

4244
question_deep_map = {
@@ -106,7 +108,7 @@ def __init__(
106108
critic_llm: BaseLLM | BaseChatModel,
107109
embeddings_model: Embeddings,
108110
testset_distribution: t.Optional[t.Dict[str, float]] = None,
109-
chat_qa: float = 0.3,
111+
chat_qa: float = 0.0,
110112
chunk_size: int = 1024,
111113
seed: int = 42,
112114
) -> None:
@@ -135,7 +137,7 @@ def from_default(
135137
openai_generator_llm: str = "gpt-3.5-turbo-16k",
136138
openai_filter_llm: str = "gpt-4",
137139
chat_qa: float = 0.3,
138-
chunk_size: int = 1024,
140+
chunk_size: int = 512,
139141
):
140142
generator_llm = ChatOpenAI(model=openai_generator_llm)
141143
critic_llm = ChatOpenAI(model=openai_filter_llm)
@@ -173,14 +175,12 @@ def _filter_context(self, context: str) -> bool:
173175
prompt = ChatPromptTemplate.from_messages([human_prompt])
174176
results = generate(prompts=[prompt], llm=self.critic_llm)
175177
output = results.generations[0][0].text.strip()
176-
score = eval(output)
177-
if not isinstance(score, float | int):
178-
index = output.lower().find("score:")
179-
if index != -1:
180-
index += len("score:")
181-
score = eval(output[index:])
182-
else:
183-
score = 0.0
178+
pattern = r"^[\d.]+$"
179+
if not re.match(pattern, output):
180+
score = 0.0
181+
else:
182+
score = eval(output)
183+
184184
return score >= self.threshold
185185

186186
def _seed_question(self, context: str) -> str:
@@ -241,22 +241,30 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
241241
for qstn in question.split("\n")
242242
]
243243

244-
def _remove_index(self, available_indices: list, node_idx: list) -> t.List:
244+
def _remove_nodes(self, available_indices: list, node_idx: list) -> t.List:
245245
for idx in node_idx:
246246
available_indices.remove(idx)
247247
return available_indices
248248

249-
def _generate_doc_node_map(
249+
def _generate_doc_nodes_map(
250250
self, documenet_nodes: t.List[BaseNode]
251-
) -> t.Dict[str, list]:
252-
doc_nodeidx = defaultdict(list)
253-
for idx, node in enumerate(documenet_nodes):
254-
doc_nodeidx[node.id_].append(idx)
255-
256-
return doc_nodeidx
257-
258-
def _get_neighbour_node(self, idx: int, node_indices: list) -> t.List[int]:
259-
return [idx - 1, idx] if idx == node_indices[-1] else [idx, idx + 1]
251+
) -> t.Dict[str, BaseNode]:
252+
doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
253+
for node in documenet_nodes:
254+
if node.ref_doc_id:
255+
doc_nodes_map[node.ref_doc_id].append(node)
256+
257+
return doc_nodes_map # type: ignore
258+
259+
def _get_neighbour_node(
260+
self, node: BaseNode, related_nodes: list[BaseNode]
261+
) -> t.List[BaseNode]:
262+
if len(related_nodes) < 2:
263+
warnings.warn("No neighbors exists")
264+
return [node]
265+
idx = related_nodes.index(node)
266+
ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
267+
return [related_nodes[idx] for idx in ids]
260268

261269
def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
262270
embeddings = {}
@@ -275,38 +283,38 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
275283
document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
276284
documents=documents
277285
)
278-
279286
# maximum 1 seed question per node
280287
if test_size > len(document_nodes):
281288
raise ValueError(
282289
"""Maximum possible number of samples exceeded,
283290
reduce test_size or add more documents"""
284291
)
285292

286-
available_indices = np.arange(0, len(document_nodes)).tolist()
287-
doc_nodeidx = self._generate_doc_node_map(document_nodes)
293+
available_nodes = document_nodes
294+
doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
295+
count_neighbours = sum(len(val) > 1 for _, val in doc_nodes_map.items())
296+
if count_neighbours < len(documents) // 2:
297+
warnings.warn("Most documents are too short")
298+
288299
count = 0
289300
samples = []
290301

291302
pbar = tqdm(total=test_size)
292-
while count < test_size and available_indices != []:
303+
while count < test_size and available_nodes != []:
293304
evolve_type = self._get_evolve_type()
294-
node_idx = self.rng.choice(available_indices, size=1)[0]
295-
available_indices = self._remove_index(available_indices, [node_idx])
305+
curr_node = self.rng.choice(available_nodes, size=1)[0]
306+
available_nodes = self._remove_nodes(available_nodes, [curr_node])
296307

297-
neighbor_nodes = doc_nodeidx[
298-
document_nodes[node_idx].node_id # type: ignore
299-
]
308+
neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]
300309

301310
# Append multiple nodes randomly to remove chunking bias
302311
size = self.rng.integers(1, 3)
303-
node_indices = (
304-
self._get_neighbour_node(node_idx, neighbor_nodes)
312+
nodes = (
313+
self._get_neighbour_node(curr_node, neighbor_nodes)
305314
if size > 1 and evolve_type != "multi_context"
306-
else [node_idx]
315+
else [curr_node]
307316
)
308317

309-
nodes = [document_nodes[node_idx] for node_idx in node_indices]
310318
text_chunk = " ".join([node.get_content() for node in nodes])
311319
score = self._filter_context(text_chunk)
312320
if not score:
@@ -316,14 +324,13 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
316324
if evolve_type == "multi_context":
317325
# Find most similar chunk in same document
318326
node_embedding = self._embed_nodes([nodes[-1]])
319-
neighbor_nodes = self._remove_index(neighbor_nodes, node_indices)
320-
neighbor_emb = self._embed_nodes(
321-
[document_nodes[idx][0] for idx in neighbor_nodes]
322-
)
327+
neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
328+
neighbor_emb = self._embed_nodes(neighbor_nodes)
329+
323330
_, indices = get_top_k_embeddings(
324331
list(node_embedding.values())[0],
325332
list(neighbor_emb.values()),
326-
similarity_cutoff=self.threshold,
333+
similarity_cutoff=self.threshold / 10,
327334
)
328335
if indices:
329336
best_neighbor = neighbor_nodes[indices[0]]
@@ -332,7 +339,7 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
332339
context1=text_chunk,
333340
context2=best_neighbor.get_content(),
334341
)
335-
text_chunk = "\n".join([text_chunk, best_neighbor.get_context()])
342+
text_chunk = "\n".join([text_chunk, best_neighbor.get_content()])
336343
else:
337344
continue
338345

0 commit comments

Comments (0)