     _VALID_TLM_QUALITY_PRESETS_RAG,
 )
 from cleanlab_tlm.internal.exception_handling import handle_tlm_exceptions
-from cleanlab_tlm.internal.validation import tlm_score_process_response_and_kwargs, validate_rag_inputs
+from cleanlab_tlm.internal.validation import (
+    tlm_score_process_response_and_kwargs,
+    validate_rag_inputs,
+)


 if TYPE_CHECKING:
     from collections.abc import Coroutine
@@ -115,8 +118,12 @@ def __init__(
                 name=cast(str, eval_config[_TLM_EVAL_NAME_KEY]),
                 criteria=cast(str, eval_config[_TLM_EVAL_CRITERIA_KEY]),
                 query_identifier=eval_config.get(_TLM_EVAL_QUERY_IDENTIFIER_KEY),
-                context_identifier=eval_config.get(_TLM_EVAL_CONTEXT_IDENTIFIER_KEY),
-                response_identifier=eval_config.get(_TLM_EVAL_RESPONSE_IDENTIFIER_KEY),
+                context_identifier=eval_config.get(
+                    _TLM_EVAL_CONTEXT_IDENTIFIER_KEY
+                ),
+                response_identifier=eval_config.get(
+                    _TLM_EVAL_RESPONSE_IDENTIFIER_KEY
+                ),
             )
             for eval_config in _DEFAULT_EVALS
         ]
@@ -164,10 +171,16 @@ def score(
         )

         # Support constrain_outputs later
-        processed_responses = tlm_score_process_response_and_kwargs(formatted_prompts, response, None, {})
+        processed_responses = tlm_score_process_response_and_kwargs(
+            formatted_prompts, response, None, {}
+        )

         # Check if we're handling a batch or a single item
-        if isinstance(query, str) and isinstance(context, str) and isinstance(processed_responses, dict):
+        if (
+            isinstance(query, str)
+            and isinstance(context, str)
+            and isinstance(processed_responses, dict)
+        ):
             return self._event_loop.run_until_complete(
                 self._score_async(
                     response=processed_responses,
@@ -189,6 +202,74 @@ def score(
             )
         )

+    async def score_async(
+        self,
+        *,
+        response: Union[str, Sequence[str]],
+        query: Union[str, Sequence[str]],
+        context: Union[str, Sequence[str]],
+        prompt: Optional[Union[str, Sequence[str]]] = None,
+        form_prompt: Optional[Callable[[str, str], str]] = None,
+    ) -> Union[TrustworthyRAGScore, list[TrustworthyRAGScore]]:
+        """
+        Evaluate an existing RAG system's response to a given user query and retrieved context.
+
+        Args:
+            response (str | Sequence[str]): A response (or list of multiple responses) from your LLM/RAG system.
+            query (str | Sequence[str]): The user query (or list of multiple queries) that was used to generate the response.
+            context (str | Sequence[str]): The context (or list of multiple contexts) that was retrieved from the RAG Knowledge Base and used to generate the response.
+            prompt (str | Sequence[str], optional): Optional prompt (or list of multiple prompts) representing the actual inputs (combining query, context, and system instructions into one string) to the LLM that generated the response.
+            form_prompt (Callable[[str, str], str], optional): Optional function to format the prompt based on query and context. Cannot be provided together with prompt; provide one or the other.
+                This function should take query and context as parameters and return a formatted prompt string.
+                If not provided, a default prompt formatter will be used.
+                To include a system prompt or any other special instructions for your LLM,
+                incorporate them directly in your custom `form_prompt()` function definition.
+
+        Returns:
+            TrustworthyRAGScore | list[TrustworthyRAGScore]: [TrustworthyRAGScore](#class-trustworthyragscore) object containing evaluation metrics.
+                If multiple inputs were provided in lists, a list of TrustworthyRAGScore objects is returned, one for each set of inputs.
+        """
+        if prompt is None and form_prompt is None:
+            form_prompt = TrustworthyRAG._default_prompt_formatter
+
+        formatted_prompts = validate_rag_inputs(
+            query=query,
+            context=context,
+            response=response,
+            prompt=prompt,
+            form_prompt=form_prompt,
+            evals=self._evals,
+            is_generate=False,
+        )
+
+        # Support constrain_outputs later
+        processed_responses = tlm_score_process_response_and_kwargs(
+            formatted_prompts, response, None, {}
+        )
+
+        # Check if we're handling a batch or a single item
+        if (
+            isinstance(query, str)
+            and isinstance(context, str)
+            and isinstance(processed_responses, dict)
+        ):
+            return await self._score_async(
+                response=processed_responses,
+                prompt=formatted_prompts,
+                query=query,
+                context=context,
+                timeout=self._timeout,
+            )
+
+        # Batch processing
+        return await self._batch_score(
+            responses=cast(Sequence[dict[str, Any]], processed_responses),
+            prompts=formatted_prompts,
+            queries=query,
+            contexts=context,
+            capture_exceptions=False,
+        )
+
     def generate(
         self,
         *,
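For readers skimming the diff: the new `score_async` method mirrors the synchronous `score`, but can be awaited directly from a running event loop instead of going through `run_until_complete`. Below is a minimal usage sketch, assuming `TrustworthyRAG` is importable from the package's top level and constructible with default arguments (neither detail is shown in this diff):

```python
# Hedged sketch only; import path and constructor arguments are assumptions, not from this diff.
import asyncio

from cleanlab_tlm import TrustworthyRAG  # import path assumed


async def main() -> None:
    rag_evaluator = TrustworthyRAG()  # constructor arguments assumed/defaulted (e.g. API key via environment)
    score = await rag_evaluator.score_async(
        query="What is the return policy?",
        context="Items may be returned for a full refund within 30 days of purchase.",
        response="You can return items within 30 days for a full refund.",
    )
    print(score)


asyncio.run(main())
```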
@@ -212,11 +293,20 @@ def generate(
             form_prompt = TrustworthyRAG._default_prompt_formatter

         formatted_prompts = validate_rag_inputs(
-            query=query, context=context, prompt=prompt, form_prompt=form_prompt, evals=self._evals, is_generate=True
+            query=query,
+            context=context,
+            prompt=prompt,
+            form_prompt=form_prompt,
+            evals=self._evals,
+            is_generate=True,
         )

         # Check if we're handling a batch or a single item
-        if isinstance(query, str) and isinstance(context, str) and isinstance(formatted_prompts, str):
+        if (
+            isinstance(query, str)
+            and isinstance(context, str)
+            and isinstance(formatted_prompts, str)
+        ):
             return self._event_loop.run_until_complete(
                 self._generate_async(
                     prompt=formatted_prompts,
@@ -287,7 +377,9 @@ async def _batch_generate(
                     capture_exceptions=capture_exceptions,
                     batch_index=batch_index,
                 )
-                for batch_index, (prompt, query, context) in enumerate(zip(prompts, queries, contexts))
+                for batch_index, (prompt, query, context) in enumerate(
+                    zip(prompts, queries, contexts)
+                )
             ],
             per_batch_timeout,
         )
@@ -344,7 +436,9 @@ async def _batch_score(

     async def _batch_async(
         self,
-        rag_coroutines: Sequence[Coroutine[None, None, Union[TrustworthyRAGResponse, TrustworthyRAGScore]]],
+        rag_coroutines: Sequence[
+            Coroutine[None, None, Union[TrustworthyRAGResponse, TrustworthyRAGScore]]
+        ],
         batch_timeout: Optional[float] = None,
     ) -> Sequence[Union[TrustworthyRAGResponse, TrustworthyRAGScore]]:
         """Runs batch of TrustworthyRAG operations.
@@ -516,7 +610,9 @@ def _default_prompt_formatter(query: str, context: str) -> str:
         prompt_parts.append("---------------------\n")

         # Add instruction to use context
-        prompt_parts.append("Using the context information provided above, please answer the following question:\n")
+        prompt_parts.append(
+            "Using the context information provided above, please answer the following question:\n"
+        )

         # Add user query
         prompt_parts.append(f"User: {query.strip()}\n")
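Related note: per the `score`/`score_async` docstrings above, a custom `form_prompt(query, context)` callable can be supplied instead of relying on `_default_prompt_formatter`. A hedged sketch of such a formatter follows; the system-instruction wording is illustrative and not taken from the library:

```python
def form_prompt_with_system_instructions(query: str, context: str) -> str:
    # Loosely mirrors the default formatter's shape: instructions, a delimited context block, then the user query.
    return (
        "You are a customer-support assistant. Answer using only the provided context.\n\n"
        "Context:\n"
        "---------------------\n"
        f"{context.strip()}\n"
        "---------------------\n"
        "Using the context information provided above, please answer the following question:\n"
        f"User: {query.strip()}\n"
    )


# Passed via the form_prompt keyword, e.g.:
# rag_evaluator.score(..., form_prompt=form_prompt_with_system_instructions)
```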
@@ -557,7 +653,11 @@ def __init__(
         lazydocs: ignore
         """
         # Validate that at least one identifier is specified
-        if query_identifier is None and context_identifier is None and response_identifier is None:
+        if (
+            query_identifier is None
+            and context_identifier is None
+            and response_identifier is None
+        ):
             raise ValueError(
                 "At least one of query_identifier, context_identifier, or response_identifier must be specified."
             )
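For context on the validation above: at least one of the three identifiers must be set when defining a custom evaluation. A hedged sketch, assuming the keyword arguments visible in the `Eval`-construction hunk near the top of this diff and an assumed import path:

```python
from cleanlab_tlm.utils.rag import Eval  # import path assumed

# Would now raise ValueError: no identifier specified.
# Eval(name="conciseness", criteria="Determine whether the Response is concise.")

# Passes the check: the criteria references the response, so response_identifier is set.
conciseness_eval = Eval(
    name="conciseness",
    criteria="Determine whether the Response is concise and avoids unnecessary detail.",
    response_identifier="Response",
)
```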