|
| 1 | +from huggingface_hub import InferenceClient |
| 2 | + |
| 3 | +from datachain import C, DataChain, DataModel |
| 4 | + |
# Instruction text placed ahead of each dialog that is sent to the LLM.
# The value intentionally begins and ends with a newline.
PROMPT = (
    "\n"
    "Was this dialog successful? Put result as a single word: Success or Failure.\n"
    "Explain the reason in a few words.\n"
)
| 9 | + |
| 10 | + |
# Structured verdict for a single dialog, as returned by eval_dialog.
# NOTE: documented with comments rather than a docstring on purpose; a class
# docstring would presumably be emitted into DialogEval.model_json_schema()
# (Pydantic behavior) and alter the schema sent to the LLM -- confirm before
# adding one.
class DialogEval(DataModel):
    # One-word outcome; the prompt asks for "Success" or "Failure".
    result: str
    # Short free-text explanation of the outcome.
    reason: str
| 14 | + |
| 15 | + |
# DataChain function that evaluates a single dialog.
# DataChain uses the parameter and return type annotations to infer the
# schema automatically.
def eval_dialog(user_input: str, bot_response: str) -> DialogEval:
    """Ask the hosted LLM whether one user/bot exchange was successful.

    Args:
        user_input: What the user said.
        bot_response: What the bot answered.

    Returns:
        The ``DialogEval`` parsed from the model's JSON reply. If parsing
        raises ``ValueError``, a placeholder with ``result="Error"`` is
        returned instead. Errors raised by the inference call itself are
        not caught here.
    """
    llm = InferenceClient("meta-llama/Llama-3.1-70B-Instruct")

    # Single user turn: the instructions followed by the dialog transcript.
    question = {
        "role": "user",
        "content": f"{PROMPT}\n\nUser: {user_input}\nBot: {bot_response}",
    }
    # Request JSON output constrained by the DialogEval schema.
    output_spec = {"type": "json", "value": DialogEval.model_json_schema()}

    reply = llm.chat_completion(messages=[question], response_format=output_spec)
    answer = reply.choices[0].message

    try:
        verdict = DialogEval.model_validate_json(answer.content)
    except ValueError:
        verdict = DialogEval(result="Error", reason="Failed to parse response.")
    return verdict
| 36 | + |
| 37 | + |
# Run HF inference in parallel over every example in the CSV.
# eval_dialog returns a Pydantic model that DataChain can understand and
# serialize, stored under the "response" column.
# The result is saved to HF as Parquet. The dataset can be previewed here:
# https://huggingface.co/datasets/dvcorg/test-datachain-llm-eval/viewer
dialogs = DataChain.from_csv(
    "hf://datasets/infinite-dataset-hub/MobilePlanAssistant/data.csv"
)
scored = dialogs.settings(parallel=10).map(response=eval_dialog)
scored.to_parquet("hf://datasets/dvcorg/test-datachain-llm-eval/data.parquet")
| 50 | + |
# Read the Parquet file back, keep only the failed dialogs, and show a few.
# DataChain restores the Pydantic model from Parquet under the hood, so the
# nested "response.result" field can be filtered on directly.
failures = DataChain.from_parquet(
    "hf://datasets/dvcorg/test-datachain-llm-eval/data.parquet", source=False
).filter(C("response.result") == "Failure")
failures.show(3)
0 commit comments