
Commit 40918fc

add async score function to trustworthyrag module
1 parent 5801dc1 commit 40918fc

File tree: 1 file changed, +111 −11 lines


src/cleanlab_tlm/utils/rag.py

Lines changed: 111 additions & 11 deletions
@@ -40,7 +40,10 @@
     _VALID_TLM_QUALITY_PRESETS_RAG,
 )
 from cleanlab_tlm.internal.exception_handling import handle_tlm_exceptions
-from cleanlab_tlm.internal.validation import tlm_score_process_response_and_kwargs, validate_rag_inputs
+from cleanlab_tlm.internal.validation import (
+    tlm_score_process_response_and_kwargs,
+    validate_rag_inputs,
+)

 if TYPE_CHECKING:
     from collections.abc import Coroutine
@@ -115,8 +118,12 @@ def __init__(
                 name=cast(str, eval_config[_TLM_EVAL_NAME_KEY]),
                 criteria=cast(str, eval_config[_TLM_EVAL_CRITERIA_KEY]),
                 query_identifier=eval_config.get(_TLM_EVAL_QUERY_IDENTIFIER_KEY),
-                context_identifier=eval_config.get(_TLM_EVAL_CONTEXT_IDENTIFIER_KEY),
-                response_identifier=eval_config.get(_TLM_EVAL_RESPONSE_IDENTIFIER_KEY),
+                context_identifier=eval_config.get(
+                    _TLM_EVAL_CONTEXT_IDENTIFIER_KEY
+                ),
+                response_identifier=eval_config.get(
+                    _TLM_EVAL_RESPONSE_IDENTIFIER_KEY
+                ),
             )
             for eval_config in _DEFAULT_EVALS
         ]
@@ -164,10 +171,16 @@ def score(
         )

         # Support constrain_outputs later
-        processed_responses = tlm_score_process_response_and_kwargs(formatted_prompts, response, None, {})
+        processed_responses = tlm_score_process_response_and_kwargs(
+            formatted_prompts, response, None, {}
+        )

         # Check if we're handling a batch or a single item
-        if isinstance(query, str) and isinstance(context, str) and isinstance(processed_responses, dict):
+        if (
+            isinstance(query, str)
+            and isinstance(context, str)
+            and isinstance(processed_responses, dict)
+        ):
             return self._event_loop.run_until_complete(
                 self._score_async(
                     response=processed_responses,
@@ -189,6 +202,74 @@ def score(
             )
         )

+    async def score_async(
+        self,
+        *,
+        response: Union[str, Sequence[str]],
+        query: Union[str, Sequence[str]],
+        context: Union[str, Sequence[str]],
+        prompt: Optional[Union[str, Sequence[str]]] = None,
+        form_prompt: Optional[Callable[[str, str], str]] = None,
+    ) -> Union[TrustworthyRAGScore, list[TrustworthyRAGScore]]:
+        """
+        Evaluate an existing RAG system's response to a given user query and retrieved context.
+
+        Args:
+            response (str | Sequence[str]): A response (or list of multiple responses) from your LLM/RAG system.
+            query (str | Sequence[str]): The user query (or list of multiple queries) that was used to generate the response.
+            context (str | Sequence[str]): The context (or list of multiple contexts) that was retrieved from the RAG Knowledge Base and used to generate the response.
+            prompt (str | Sequence[str], optional): Optional prompt (or list of multiple prompts) representing the actual inputs (combining query, context, and system instructions into one string) to the LLM that generated the response.
+            form_prompt (Callable[[str, str], str], optional): Optional function to format the prompt based on query and context. Cannot be provided together with prompt, provide one or the other.
+                This function should take query and context as parameters and return a formatted prompt string.
+                If not provided, a default prompt formatter will be used.
+                To include a system prompt or any other special instructions for your LLM,
+                incorporate them directly in your custom `form_prompt()` function definition.
+
+        Returns:
+            TrustworthyRAGScore | list[TrustworthyRAGScore]: [TrustworthyRAGScore](#class-trustworthyragscore) object containing evaluation metrics.
+            If multiple inputs were provided in lists, a list of TrustworthyRAGScore objects is returned, one for each set of inputs.
+        """
+        if prompt is None and form_prompt is None:
+            form_prompt = TrustworthyRAG._default_prompt_formatter
+
+        formatted_prompts = validate_rag_inputs(
+            query=query,
+            context=context,
+            response=response,
+            prompt=prompt,
+            form_prompt=form_prompt,
+            evals=self._evals,
+            is_generate=False,
+        )
+
+        # Support constrain_outputs later
+        processed_responses = tlm_score_process_response_and_kwargs(
+            formatted_prompts, response, None, {}
+        )
+
+        # Check if we're handling a batch or a single item
+        if (
+            isinstance(query, str)
+            and isinstance(context, str)
+            and isinstance(processed_responses, dict)
+        ):
+            return await self._score_async(
+                response=processed_responses,
+                prompt=formatted_prompts,
+                query=query,
+                context=context,
+                timeout=self._timeout,
+            )
+
+        # Batch processing
+        return await self._batch_score(
+            responses=cast(Sequence[dict[str, Any]], processed_responses),
+            prompts=formatted_prompts,
+            queries=query,
+            contexts=context,
+            capture_exceptions=False,
+        )
+
     def generate(
         self,
         *,
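For orientation, a minimal sketch of calling the new score_async method added above. Everything outside the method call itself is an assumption not shown in this diff: the zero-argument TrustworthyRAG() constructor, API-key handling, and the example query/context/response strings are illustrative only.

import asyncio

from cleanlab_tlm.utils.rag import TrustworthyRAG


async def main() -> None:
    # Constructor defaults and credentials are assumptions, not part of this diff.
    rag_evaluator = TrustworthyRAG()

    # Single (non-batch) inputs take the str/str/dict branch above and
    # return a single TrustworthyRAGScore.
    score = await rag_evaluator.score_async(
        query="What is the capital of France?",
        context="France is a country in Europe. Its capital is Paris.",
        response="The capital of France is Paris.",
    )
    print(score)


asyncio.run(main())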
@@ -212,11 +293,20 @@ def generate(
             form_prompt = TrustworthyRAG._default_prompt_formatter

         formatted_prompts = validate_rag_inputs(
-            query=query, context=context, prompt=prompt, form_prompt=form_prompt, evals=self._evals, is_generate=True
+            query=query,
+            context=context,
+            prompt=prompt,
+            form_prompt=form_prompt,
+            evals=self._evals,
+            is_generate=True,
         )

         # Check if we're handling a batch or a single item
-        if isinstance(query, str) and isinstance(context, str) and isinstance(formatted_prompts, str):
+        if (
+            isinstance(query, str)
+            and isinstance(context, str)
+            and isinstance(formatted_prompts, str)
+        ):
             return self._event_loop.run_until_complete(
                 self._generate_async(
                     prompt=formatted_prompts,
@@ -287,7 +377,9 @@ async def _batch_generate(
                     capture_exceptions=capture_exceptions,
                     batch_index=batch_index,
                 )
-                for batch_index, (prompt, query, context) in enumerate(zip(prompts, queries, contexts))
+                for batch_index, (prompt, query, context) in enumerate(
+                    zip(prompts, queries, contexts)
+                )
             ],
             per_batch_timeout,
         )
@@ -344,7 +436,9 @@ async def _batch_score(

     async def _batch_async(
         self,
-        rag_coroutines: Sequence[Coroutine[None, None, Union[TrustworthyRAGResponse, TrustworthyRAGScore]]],
+        rag_coroutines: Sequence[
+            Coroutine[None, None, Union[TrustworthyRAGResponse, TrustworthyRAGScore]]
+        ],
         batch_timeout: Optional[float] = None,
     ) -> Sequence[Union[TrustworthyRAGResponse, TrustworthyRAGScore]]:
         """Runs batch of TrustworthyRAG operations.
@@ -516,7 +610,9 @@ def _default_prompt_formatter(query: str, context: str) -> str:
         prompt_parts.append("---------------------\n")

         # Add instruction to use context
-        prompt_parts.append("Using the context information provided above, please answer the following question:\n")
+        prompt_parts.append(
+            "Using the context information provided above, please answer the following question:\n"
+        )

         # Add user query
         prompt_parts.append(f"User: {query.strip()}\n")
@@ -557,7 +653,11 @@ def __init__(
         lazydocs: ignore
         """
         # Validate that at least one identifier is specified
-        if query_identifier is None and context_identifier is None and response_identifier is None:
+        if (
+            query_identifier is None
+            and context_identifier is None
+            and response_identifier is None
+        ):
             raise ValueError(
                 "At least one of query_identifier, context_identifier, or response_identifier must be specified."
             )
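Per the validation in the last hunk, an Eval must set at least one of the query/context/response identifiers. A hedged sketch of constructing one and passing it to TrustworthyRAG; the criteria text, identifier strings, and the evals= constructor argument are assumptions based on how self._evals is used earlier in this file.

from cleanlab_tlm.utils.rag import Eval, TrustworthyRAG

# At least one identifier must be provided, otherwise the ValueError above is raised.
groundedness_eval = Eval(
    name="response_groundedness",
    criteria="Determine whether the Response is fully supported by the Context.",
    context_identifier="Context",
    response_identifier="Response",
)

rag_evaluator = TrustworthyRAG(evals=[groundedness_eval])  # evals= argument is an assumption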
