From 26ce4fa3e760f5f40d67e2b75fda19004ac44183 Mon Sep 17 00:00:00 2001
From: bugtig6351
Date: Sun, 24 Mar 2024 22:42:08 +0800
Subject: [PATCH] tmp

---
 benchmarks/ASQA/asqa_benchmark.py | 10 ++--
 benchmarks/ASQA/generate.py       |  4 +-
 benchmarks/base.py                | 91 ++++++++++++++++++++++++-------
 3 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/benchmarks/ASQA/asqa_benchmark.py b/benchmarks/ASQA/asqa_benchmark.py
index 49407fb..ebd3be9 100644
--- a/benchmarks/ASQA/asqa_benchmark.py
+++ b/benchmarks/ASQA/asqa_benchmark.py
@@ -5,6 +5,7 @@ import logging
 import argparse
 
 from benchmarks import BaseBenchmark
+from rageval.metrics import (AnswerRougeCorrectness, AnswerEMCorrectness, AnswerDisambigF1Correctness)
 
 logger = logging.getLogger(__name__)
 
@@ -12,25 +13,26 @@ class ASQABenchmark(BaseBenchmark):
 
     name = "asqa_benchmark"
-    metrics = ["AnswerRougeCorrectness", "AnswerEMCorrectness", "AnswerDisambigF1Correctness"]
+    metrics = [AnswerRougeCorrectness(rouge_type="rougeL"),
+               AnswerEMCorrectness(ignore_case=True),
+               AnswerDisambigF1Correctness()]
 
     def __init__(self, output_dir: str, batch_size: int = 1) -> None:
         self.output_dir = output_dir
         self.batch_size = batch_size
 
-    def load_data(self, **kwargs) -> Dataset:
+    def load_data(self, **kwargs):
         """Load ASQA dataset.
 
         For the ASQA dataset, the `short_answers` and `long_answers` are stored in the "qa_pairs" and "annotations" columns, respectively. We need to extract them and add them to the dataset.
         """
         print("Load ASQA dataset...")
-        self.dataset = load_dataset(**kwargs)
+        super().load_data(**kwargs)
         if "short_answers" not in dataset.features:
             self.dataset = self.dataset.map(lambda example: {"short_answers": [ann["short_answers"] for ann in example["qa_pairs"]]})
         if "long_answers" not in dataset.features:
             self.dataset = self.dataset.map(lambda example: {"long_answers": [ann["long_answer"] for ann in example["annotations"]]})
         print("ASQA dataset loaded.")
-        return self.dataset
 
     def evaluate(self, dataset_name:str = "result_dataset", result_name:str = "results") -> Dataset:
         """Evaluate the dataset and return the dataset with scores.
diff --git a/benchmarks/ASQA/generate.py b/benchmarks/ASQA/generate.py
index 01c0533..770a417 100644
--- a/benchmarks/ASQA/generate.py
+++ b/benchmarks/ASQA/generate.py
@@ -100,7 +100,7 @@ def extract_key_information(pred: str) -> str:
     pred = re.sub(r'\(\d+\)\s', '', pred)  # remove the index numbers
     return pred
 
-def generete_answers(engine: InstructGPT, dataset: Dataset) -> Dataset:
+def generate_answers(engine: InstructGPT, dataset: Dataset) -> Dataset:
     prompts = [
         PROMPT.format(few_shot_examples=FEW_SHOT_EXAMPLES,
                       question=data['ambiguous_question'])
@@ -136,7 +136,7 @@ def generete_answers(engine: InstructGPT, dataset: Dataset) -> Dataset:
                               max_tokens=args.max_new_tokens)
 
     print("Start generate answers...")
-    dataset = generete_answers(engine, dataset)
+    dataset = generate_answers(engine, dataset)
 
     dataset.to_json(f"{args.output_dir}/{args.dataset_name}.jsonl")
     print(f"\nFinish generate dataset. Dataset saved as {args.output_dir}/{args.dataset_name}.jsonl")
diff --git a/benchmarks/base.py b/benchmarks/base.py
index 98926d8..4eeed37 100644
--- a/benchmarks/base.py
+++ b/benchmarks/base.py
@@ -1,14 +1,15 @@
-from typing import List, Union
+import json
+from typing import List, Union, Dict, Any, Tuple, Optional
 from abc import abstractmethod, ABC
 from dataclasses import dataclass
-import importlib
-from datasets import Dataset
-from rageval.metrics import Metric, MetricWithLLM
+# import importlib
+from datasets import Dataset, load_dataset
+from rageval.metrics import Metric
 
 class BaseBenchmark(ABC):
     """Base class for benchmarks."""
 
-    metrics: List[str]
+    metrics: List[Metric] = []
+    dataset: Dataset
 
     def __init__() -> None:
         """Initialization."""
@@ -20,22 +21,74 @@ def name(self) -> str:
         """The benchmark name."""
         ...
 
+    @property
+    def metric_names(self) -> List[str]:
+        """The metric names."""
+        return [m.name for m in self.metrics]
+
     @abstractmethod
-    def load_data(self,) -> Dataset:
+    def load_data(self, **kwargs) -> None:
         """Load the dataset with answers to evaluate."""
-        ...
+        self.dataset = load_dataset(**kwargs)
 
     @abstractmethod
-    def evaluate(self,) -> Dataset:
-        """Evaluate the dataset and return the dataset with scores."""
+    def _evaluate(self) -> Tuple[Dict[str, Any], Dataset]:
+        """Evaluate the dataset and return the aggregated results together with the dataset annotated with per-sample scores."""
         ...
-
-    # @abstractmethod
-    # def save_result(self,) -> None:
-    #     """Save the result to files."""
-    #     ...
-
-    def get_metric(self, name: str, **kwargs) -> Union[Metric, MetricWithLLM]:
-        """Get the metric by name."""
-        module = importlib.import_module(f"rageval.metrics")
-        return getattr(module, name)(**kwargs)
+
+    def prepare_data(self, input_column: str, label_column: Optional[str], **kwargs) -> Dataset:
+        """Prepare the dataset for a given metric.
+
+        Args:
+            input_column: The name of a column that already exists in self.dataset, e.g. `long_answer`.
+            label_column: The column name that the metric expects, e.g. `gt_answer`.
+        """
+        if input_column not in self.dataset.column_names:
+            raise ValueError(f"The input column {input_column} is not in the dataset. Please check the column names.")
+
+        if not label_column:
+            return self.dataset
+        else:
+            return self.dataset.add_column(label_column, self.dataset[input_column])
+
+    def cal(self, metric: Metric, dataset: Dataset, batch_size: Optional[int] = None) -> Tuple[float, Dataset]:
+        """Calculate the score of a single metric on the dataset."""
+        score, ds = metric.compute(dataset, batch_size)
+        return score, ds
+
+    def evaluate(self, **kwargs) -> Dict[str, Any]:
+        """Load the dataset, evaluate it, and return a result dict."""
+        self.load_data(**kwargs)
+        self.results, self.dataset = self._evaluate()
+        return self.results
+
+    def set_metric(self, metrics: List[Metric]) -> None:
+        """Reset the metrics."""
+        if all(isinstance(m, Metric) for m in metrics):
+            self.metrics = metrics
+        else:
+            raise ValueError("The metrics should be a list of Metric objects.")
+
+    def save_dataset(self, file_path: str) -> None:
+        """Save the evaluated dataset to a file."""
+        if not hasattr(self, "dataset"):
+            raise ValueError("Please load the dataset and evaluate it first.")
+        self.dataset.to_json(file_path, orient="records")
+        print(f"Dataset saved to {file_path}.")
+
+    def save_results(self, file_path: str) -> None:
+        """Save the evaluation results to a file."""
+        if not hasattr(self, "results"):
+            raise ValueError("Please run evaluation first.")
+        with open(file_path, "w") as f:
+            json.dump(self.results, f, indent=4)
+        print(f"Results saved to {file_path}.")
+
+    # def get_metric(self, name: str, **kwargs) -> Union[Metric, MetricWithLLM]:
+    #     """Get the metric by name."""
+    #     module = importlib.import_module(f"rageval.metrics")
+    #     return getattr(module, name)(**kwargs)
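
Below is a minimal usage sketch (not part of the patch) of how a subclass is expected to plug into the refactored BaseBenchmark: it implements _evaluate(), relies on the base load_data() to forward its kwargs to datasets.load_dataset, and then saves the outputs. The ToyBenchmark class, the answers.jsonl path, the "output" directory, and the evaluate() keyword arguments are illustrative assumptions, not code from the repository; the column names required by the chosen metric are likewise assumed to exist in the loaded file.

# Hypothetical example built on the refactored BaseBenchmark above.
from typing import Any, Dict, Tuple

from datasets import Dataset
from rageval.metrics import AnswerEMCorrectness

from benchmarks.base import BaseBenchmark


class ToyBenchmark(BaseBenchmark):
    """Toy benchmark used only to illustrate the new base-class workflow."""

    name = "toy_benchmark"

    def __init__(self, output_dir: str) -> None:
        self.output_dir = output_dir

    def load_data(self, **kwargs) -> None:
        # The base implementation forwards kwargs to datasets.load_dataset.
        super().load_data(**kwargs)

    def _evaluate(self) -> Tuple[Dict[str, Any], Dataset]:
        # Score every configured metric and collect the aggregated numbers;
        # per-sample scores stay attached to the returned dataset.
        results = {}
        dataset = self.dataset
        for metric in self.metrics:
            score, dataset = metric.compute(dataset, 1)
            results[metric.name] = score
        return results, dataset


if __name__ == "__main__":
    benchmark = ToyBenchmark(output_dir="output")
    benchmark.set_metric([AnswerEMCorrectness(ignore_case=True)])
    # Assumes a local JSONL file with generated answers and gold answers.
    results = benchmark.evaluate(path="json", data_files="output/answers.jsonl", split="train")
    benchmark.save_dataset("output/answers_with_scores.jsonl")
    benchmark.save_results("output/results.json")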