Merged
6 changes: 3 additions & 3 deletions src/lightspeed_core_evaluation/driver.py
@@ -1,7 +1,7 @@
"""Driver for evaluation."""

import argparse
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path

from httpx import Client
@@ -15,9 +15,9 @@
from .eval_run_common import add_common_arguments


def _args_parser(args) -> argparse.ArgumentParser:
def _args_parser(args: list[str]) -> Namespace:
"""Arguments parser."""
parser = argparse.ArgumentParser(description="Response validation module.")
parser = ArgumentParser(description="Response validation module.")
# Add arguments common to all eval scripts
add_common_arguments(parser)

19 changes: 10 additions & 9 deletions src/lightspeed_core_evaluation/rag_eval.py
@@ -1,17 +1,18 @@
"""RAG Evaluation."""

import argparse
import json
import os
import sys
from argparse import ArgumentParser, Namespace
from datetime import UTC, datetime
from time import sleep
from typing import Any

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from numpy import argsort, array
from ols import config
from pandas import read_parquet
from pandas import DataFrame, read_parquet
from tqdm import tqdm

from .eval_run_common import add_common_arguments
@@ -30,9 +31,9 @@
tqdm.pandas()


def _args_parser(args):
def _args_parser(args: list[str]) -> Namespace:
"""Arguments parser."""
parser = argparse.ArgumentParser(description="RAG evaluation module.")
parser = ArgumentParser(description="RAG evaluation module.")
# Add arguments common to all eval scripts
add_common_arguments(parser)

@@ -58,7 +59,7 @@ def _args_parser(args):
class RetrievalEvaluation: # pylint: disable=R0903
"""Evaluate Retrieval."""

def __init__(self, eval_args) -> None:
def __init__(self, eval_args: Namespace) -> None:
"""Initialize."""
print(f"Arguments: {eval_args}")
self._args = eval_args
@@ -104,7 +105,7 @@ def _set_directories(self) -> tuple[str, str]:
os.makedirs(result_dir, exist_ok=True)
return input_dir, result_dir

def _load_qna_pool_parquet(self):
def _load_qna_pool_parquet(self) -> DataFrame:
"""Load QnA pool from parquet file."""
input_file = self._args.qna_pool_file
if not input_file:
@@ -123,7 +124,7 @@ def _load_qna_pool_parquet(self):
].reset_index(drop=True)
return qna_pool_df

def _load_and_process_chunks(self, query) -> str:
def _load_and_process_chunks(self, query: str) -> str:
"""Load and process chunks."""
nodes = self._retriever.retrieve(query)
chunks = [
@@ -133,7 +134,7 @@ def _load_and_process_chunks(self, query) -> str:
]
return "\n\n".join(chunks)

def _get_judge_response(self, query):
def _get_judge_response(self, query: str) -> dict[str, Any]:
"""Get Judge response."""
print("Getting Judge response...")
result = {}
@@ -156,7 +157,7 @@ def _get_judge_response(self, query):

return result

def _process_score(self, score_data) -> float:
def _process_score(self, score_data: dict[str, Any]) -> float:
"""Process score."""
relevance_score = array(score_data["relevance_score"])
completeness_score = array(score_data["completeness_score"])
53 changes: 31 additions & 22 deletions src/lightspeed_core_evaluation/response_evaluation.py
@@ -2,10 +2,12 @@

import json
import os
from argparse import Namespace
from collections import defaultdict
from datetime import UTC, datetime
from time import sleep

from httpx import Client
from ols import config
from pandas import DataFrame, concat, read_csv, read_parquet
from tqdm import tqdm
@@ -33,7 +35,7 @@
class ResponseEvaluation: # pylint: disable=R0902
"""Evaluate LLM response."""

def __init__(self, eval_args, api_client):
def __init__(self, eval_args: Namespace, api_client: Client) -> None:
"""Initialize."""
print(f"Response evaluation arguments: {eval_args}")
self._args = eval_args
@@ -83,7 +85,7 @@ def _load_config_and_rag(self) -> None:
if config.rag_index is None:
raise RuntimeError("No valid rag index for ols_rag mode")

def _set_directories(self):
def _set_directories(self) -> tuple[str, str]:
"""Set input/output directories.""" # pylint: disable=R0801
eval_dir = os.path.dirname(__file__)
input_dir = os.path.join(eval_dir, DEFAULT_INPUT_DIR)
@@ -94,7 +96,7 @@ def _set_directories(self):
os.makedirs(result_dir, exist_ok=True)
return input_dir, result_dir

def _load_qna_pool_parquet(self):
def _load_qna_pool_parquet(self) -> DataFrame:
"""Load QnA pool from parquet file."""
qna_pool_df = DataFrame()
if self._args.qna_pool_file is not None:
@@ -109,7 +111,7 @@ def _load_qna_pool_parquet(self):
qna_pool_df["in_use"] = True
return qna_pool_df

def _restructure_qna_pool_json(self, provider_model_id):
def _restructure_qna_pool_json(self, provider_model_id: str) -> DataFrame:
"""Restructure qna pool json data to dataframe."""
qna_pool_dict = defaultdict(list)

@@ -144,9 +146,9 @@ def _restructure_qna_pool_json(self, provider_model_id):

return DataFrame.from_dict(qna_pool_dict)

def _get_inscope_qna(self, provider_model_id):
def _get_inscope_qna(self, provider_model_id: str) -> DataFrame:
"""Get QnAs which are inscope for evaluation."""
qna_pool_df = self._restructure_qna_pool_json(provider_model_id)
qna_pool_df: DataFrame = self._restructure_qna_pool_json(provider_model_id)

qna_pool_df = concat([qna_pool_df, self._qa_pool_df])

@@ -157,15 +159,15 @@ def _get_inscope_qna(self, provider_model_id):
qna_pool_df = qna_pool_df[qna_pool_df.in_use]
return qna_pool_df.reset_index(drop=True).drop(columns="in_use")

def _get_api_response(
def _get_api_response( # pylint: disable=R0913,R0917
self,
question,
provider,
model,
eval_mode,
retry_attempts=MAX_RETRY_ATTEMPTS,
time_to_breath=TIME_TO_BREATH,
): # pylint: disable=R0913,R0917
question: str,
provider: str,
model: str,
eval_mode: str,
retry_attempts: int = MAX_RETRY_ATTEMPTS,
time_to_breath: int = TIME_TO_BREATH,
) -> str:
"""Get api response for a question/query."""
response = None
# try to retrieve response even when model is not responding reliably
@@ -194,9 +196,9 @@ def _get_api_response(
)
return response

def _get_recent_response(
self, question, recent_resp_df, provider, model, eval_mode
): # pylint: disable=R0913,R0917
def _get_recent_response( # pylint: disable=R0913,R0917
Contributor:

Is disabling the pylint checks still needed?

Collaborator (Author):

Yes:

  • either we need to change the threshold for arguments in the pylint settings, or
  • restructure the code/logic to reduce the number of parameters (see the sketch below).

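As a rough sketch of the second option (nothing below is part of this PR; the `ResponseQuery` name and its field grouping are hypothetical), the per-question arguments could be bundled into a small dataclass so the signature drops back under pylint's default argument limit:

```python
from dataclasses import dataclass


@dataclass
class ResponseQuery:
    """Hypothetical bundle of the per-question lookup arguments."""

    question: str
    provider: str
    model: str
    eval_mode: str


# With the bundle, the method would take two parameters instead of six,
# and the R0913/R0917 disables could be dropped, e.g.:
#
# def _get_recent_response(
#     self, query: ResponseQuery, recent_resp_df: DataFrame
# ) -> str:
#     ...
```

The same object could be threaded through to `_get_api_response`, which carries the same disables. The other option is raising the thresholds in the pylint settings, i.e. `max-args` for R0913 and, on newer pylint releases, `max-positional-arguments` for R0917.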
self,
question: str,
recent_resp_df: DataFrame,
provider: str,
model: str,
eval_mode: str,
) -> str:
"""Get llm response from the stored data, if available."""
if recent_resp_df is not None:
try:
@@ -211,7 +218,9 @@ def _get_recent_response(
# Recent response is not found, call api to get response
return self._get_api_response(question, provider, model, eval_mode)

def _get_model_response(self, qna_pool_df, provider_model_id, eval_mode):
def _get_model_response(
self, qna_pool_df: DataFrame, provider_model_id: str, eval_mode: str
) -> DataFrame:
"""Get model responses for all questions."""
temp_resp_file = (
f"{self._result_dir}/{eval_mode}_"
@@ -239,7 +248,7 @@ def _get_model_response(self, qna_pool_df, provider_model_id, eval_mode):
qna_pool_df.to_csv(temp_resp_file, index=False)
return qna_pool_df

def _get_evaluation_score(self, qna_pool_df):
def _get_evaluation_score(self, qna_pool_df: DataFrame) -> DataFrame:
"""Get response evaluation score."""
print("Getting evaluation scores...")
# Default scores
@@ -264,7 +273,7 @@ def _get_evaluation_score(self, qna_pool_df):
)
return qna_pool_df.dropna(axis=1, how="all")

def _get_response_with_score(self):
def _get_response_with_score(self) -> DataFrame:
"""Get responses with scores."""
result_dfs = []
for provider_model_id in self._args.eval_provider_model_id:
@@ -294,7 +303,7 @@ def _get_response_with_score(self):
return concat(result_dfs)

@staticmethod
def _condense_eval_df(result_df):
def _condense_eval_df(result_df: DataFrame) -> DataFrame:
"""Put all models' result as columns."""
result_df = result_df.pivot(
index=[
@@ -311,7 +320,7 @@ def _condense_eval_df(result_df):
result_df.columns = ["_".join(col) for col in result_df.columns]
return result_df

def validate_response(self):
def validate_response(self) -> bool:
"""Validate LLM response."""
consistency_success_flag = True
result_df = self._get_response_with_score()
29 changes: 17 additions & 12 deletions src/lightspeed_core_evaluation/taxonomy_eval.py
@@ -1,15 +1,16 @@
"""Taxonomy Answer/Context Evaluation."""

import argparse
import os
import sys
from argparse import ArgumentParser, Namespace
from time import sleep
from typing import Any

import yaml
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from ols import config
from pandas import DataFrame
from pandas import DataFrame, Series
from tqdm import tqdm

from .eval_run_common import add_common_arguments
@@ -29,9 +30,9 @@
# https://github.com/instructlab/taxonomy/blob/main/knowledge/arts/music/fandom/swifties/qna.yaml


def _args_parser(args):
def _args_parser(args: list[str]) -> Namespace:
"""Arguments parser."""
parser = argparse.ArgumentParser(description="Taxonomy evaluation module.")
parser = ArgumentParser(description="Taxonomy evaluation module.")
# Add arguments common to all eval scripts
add_common_arguments(parser)

@@ -60,7 +61,7 @@ def _args_parser(args):
class TaxonomyEval: # pylint: disable=R0903
"""Evaluate taxonomy answer/context."""

def __init__(self, eval_args):
def __init__(self, eval_args: Namespace) -> None:
"""Initialize."""
print(f"Arguments: {eval_args}")
self._args = eval_args
@@ -108,12 +109,14 @@ def _load_taxonomy_yaml(self) -> None:
]
self._taxonomy_df = DataFrame(data_f)

def _get_judge_response(self, question, answer, context, prompt):
def _get_judge_response(
self, question: str, answer: str, context: str, prompt: str
) -> dict[str, Any]:
"""Get Judge response."""
print("Getting Judge response...")
result = None
prompt = PromptTemplate.from_template(prompt)
judge_llm = prompt | self._judge_llm | JsonOutputParser()
llm_prompt = PromptTemplate.from_template(prompt)
judge_llm = llm_prompt | self._judge_llm | JsonOutputParser()

for retry_counter in range(MAX_RETRY_ATTEMPTS):
try:
@@ -134,7 +137,7 @@ def _get_judge_response(self, question, answer, context, prompt):

return result

def _get_score(self, df, scores, prompt):
def _get_score(self, df: DataFrame, scores: list[str], prompt: str) -> DataFrame:
"""Get score."""
df["score"] = df.progress_apply(
lambda row: self._get_judge_response(
@@ -147,7 +150,7 @@ def _get_score(self, df, scores, prompt):
df[s] = df["score"].apply(lambda x: x.get(s, None)) # pylint: disable=W0640
return df

def _get_custom_score(self):
def _get_custom_score(self) -> DataFrame:
"""Get custom score."""
df = self._taxonomy_df.copy()
if self._args.eval_type in ("all", "context"):
@@ -163,7 +166,7 @@ def _get_custom_score(self):
df.drop(columns=["score"], inplace=True)
return df

def _get_ragas_score(self):
def _get_ragas_score(self) -> DataFrame:
"""Get ragas score."""
# pylint: disable=C0415
from ragas import SingleTurnSample
@@ -172,7 +175,9 @@ def _get_ragas_score(self):

judge_llm = LangchainLLMWrapper(self._judge_llm)

def _get_score(data, scorer):
def _get_score(
data: Series, scorer: LLMContextPrecisionWithoutReference | Faithfulness
) -> float:
data = SingleTurnSample(
user_input=data.question,
response=data.answer,