Commit

tmp
bugtig6351 committed Mar 24, 2024
1 parent eef44f9 commit 26ce4fa
Showing 3 changed files with 80 additions and 25 deletions.
10 changes: 6 additions & 4 deletions benchmarks/ASQA/asqa_benchmark.py
@@ -5,32 +5,34 @@
import logging
import argparse
from benchmarks import BaseBenchmark
from rageval.metrics import (AnswerRougeCorrectness, AnswerEMCorrectness, AnswerDisambigF1Correctness)


logger = logging.getLogger(__name__)

class ASQABenchmark(BaseBenchmark):

name = "asqa_benchmark"
metrics = ["AnswerRougeCorrectness", "AnswerEMCorrectness", "AnswerDisambigF1Correctness"]
metrics = [AnswerRougeCorrectness(rouge_type="rougeL"),
AnswerEMCorrectness(ignore_case=True),
AnswerDisambigF1Correctness()]

def __init__(self, output_dir: str, batch_size: int = 1) -> None:
self.output_dir = output_dir
self.batch_size = batch_size

def load_data(self, **kwargs) -> Dataset:
def load_data(self, **kwargs):
"""Load ASQA dataset.
For the ASQA dataset, the `short_answers` and `long_answers` are stored in the "qa_pairs" and "annotations" columns, respectively. We need to extract them and add them to the dataset.
"""
print("Load ASQA dataset...")
self.dataset = load_dataset(**kwargs)
super().load_data(**kwargs)
if "short_answers" not in dataset.features:
self.dataset = self.dataset.map(lambda example: {"short_answers": [ann["short_answers"] for ann in example["qa_pairs"]]})
if "long_answers" not in dataset.features:
self.dataset = self.dataset.map(lambda example: {"long_answers": [ann["long_answer"] for ann in example["annotations"]]})
print("ASQA dataset loaded.")
return self.dataset

def evaluate(self, dataset_name: str = "result_dataset", result_name: str = "results") -> Dataset:
"""Evaluate the dataset and return the dataset with scores.
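With this change the ASQA benchmark holds instantiated metric objects instead of metric-name strings and delegates raw loading to the base class. A minimal usage sketch under that reading; the output directory, JSONL path, and keyword values are placeholders, not part of the commit:

from benchmarks.ASQA.asqa_benchmark import ASQABenchmark

benchmark = ASQABenchmark(output_dir="benchmarks/ASQA/output", batch_size=4)

# Keyword arguments are forwarded to datasets.load_dataset by the base class; the
# JSONL file is assumed to already contain model answers plus the ASQA "qa_pairs"
# and "annotations" columns.
benchmark.load_data(path="json", data_files="benchmarks/ASQA/output/answers.jsonl", split="train")

# Scores the dataset with the three metric objects declared on the class.
scored = benchmark.evaluate(dataset_name="result_dataset", result_name="results")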
4 changes: 2 additions & 2 deletions benchmarks/ASQA/generate.py
@@ -100,7 +100,7 @@ def extract_key_information(pred: str) -> str:
pred = re.sub(r'\(\d+\)\s', '', pred) # remove the index numbers
return pred

def generete_answers(engine: InstructGPT, dataset: Dataset) -> Dataset:
def generate_answers(engine: InstructGPT, dataset: Dataset) -> Dataset:
prompts = [
PROMPT.format(few_shot_examples=FEW_SHOT_EXAMPLES,
question=data['ambiguous_question'])
@@ -136,7 +136,7 @@ def generete_answers(engine: InstructGPT, dataset: Dataset) -> Dataset:
max_tokens=args.max_new_tokens)

print("Start generate answers...")
dataset = generete_answers(engine, dataset)
dataset = generate_answers(engine, dataset)

dataset.to_json(f"{args.output_dir}/{args.dataset_name}.jsonl")
print(f"\nFinish generate dataset. Dataset saved as {args.output_dir}/{args.dataset_name}.jsonl")
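The only functional change in this file is the generete_answers -> generate_answers rename. For context, the index-stripping regex in the first hunk behaves as follows; this is a standalone illustration, not code from the repository:

import re

pred = "(1) The first disambiguated answer. (2) The second disambiguated answer."
print(re.sub(r'\(\d+\)\s', '', pred))
# -> The first disambiguated answer. The second disambiguated answer.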
91 changes: 72 additions & 19 deletions benchmarks/base.py
@@ -1,14 +1,15 @@
from typing import List, Union
from typing import List, Union, Dict, Any, Tuple, Optional
from abc import abstractmethod, ABC
from dataclasses import dataclass
import importlib
from datasets import Dataset
from rageval.metrics import Metric, MetricWithLLM
import json
# import importlib
from datasets import Dataset, load_dataset
from rageval.metrics import Metric

class BaseBenchmark(ABC):
"""Base class for benchmarks."""

metrics: List[str]
metrics: List[Metric] = []
dataset: Dataset

def __init__(self) -> None:
"""Initialization."""
@@ -20,22 +21,74 @@ def name(self) -> str:
"""The benchmark name."""
...

@property
def metric_names(self) -> List[str]:
"""The metric names."""
return [m.name for m in self.metrics]

@abstractmethod
def load_data(self,) -> Dataset:
def load_data(self, **kwargs) -> None:
"""Load the dataset with answers to evaluate."""
...
self.dataset = load_dataset(**kwargs)

@abstractmethod
def evaluate(self,) -> Dataset:
"""Evaluate the dataset and return the dataset with scores."""
def _evaluate(self) -> Tuple[Dict[str, Any], Dataset]:
"""Evaluate the dataset and return the aggregate results together with the dataset annotated with per-sample scores."""
...

# @abstractmethod
# def save_result(self,) -> None:
# """Save the result to files."""
# ...

def get_metric(self, name: str, **kwargs) -> Union[Metric, MetricWithLLM]:
"""Get the metric by name."""
module = importlib.import_module(f"rageval.metrics")
return getattr(module, name)(**kwargs)

def prepare_data(self, input_column: str, label_column: Optional[str], **kwargs) -> Dataset:
"""Prepare the dataset for different metric.
Args:
input_column: The column name of the input text that has already existed in self.dataset, e.g. `long_answer`.
label_column: The column name of the label text that the metric requires, e.g. `gt_answer`.
"""
if input_column not in self.dataset.column_names:
raise ValueError(f"The input column {input_column} is not in the dataset. Please check the column names.")

if not label_column:
return self.dataset
else:
return self.dataset.add_column(label_column, self.dataset[input_column])
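# Usage sketch (not part of the commit): a metric that reads labels from a
# `gt_answer` column could be served by
#     prepare_data(input_column="long_answer", label_column="gt_answer")
# which copies the existing long_answer column into a new gt_answer column.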

def cal(self, metric: Metric, dataset: Dataset, batch_size: Optional[int] = None) -> Tuple[float, Dataset]:
"""Calculate a single metric score on the dataset."""
score, ds = metric.compute(dataset, batch_size)
return score, ds

def evaluate(self, **kwargs) -> Dict[str, Any]:
"""Load the dataset, evaluate it, and return a result dict."""
self.load_data(**kwargs)
self.results, self.dataset = self._evaluate()
return self.results

def set_metric(self, metrics: Union[List[str], List[Metric]]) -> None:
"""Reset the metrics."""
if all(isinstance(m, Metric) for m in metrics):
self.metrics = metrics
else:
raise ValueError("The metrics should be a list of Metric objects.")

def save_dataset(self, file_path: str) -> None:
"""Save the result to files."""
if not hasattr(self, "dataset"):
raise ValueError("Please load the dataset and evaluate it first.")
self.dataset.to_json(file_path, orient="records")
print(f"Dataset saved to {file_path}.")

def save_results(self, file_path: str) -> None:
"""Save the evaluation results to a file."""
if not hasattr(self, "results"):
raise ValueError("Please run evaluation first.")
with open(file_path, "w") as f:
json.dump(self.results, f, indent=4)
print(f"Results saved to {file_path}.")

# def get_metric(self, name: str, **kwargs) -> Union[Metric, MetricWithLLM]:
# """Get the metric by name."""
# module = importlib.import_module(f"rageval.metrics")
# return getattr(module, name)(**kwargs)
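Putting the reworked base class together: a subclass declares Metric objects, implements load_data and _evaluate, and reuses the inherited evaluate/prepare_data/save helpers. The sketch below is hypothetical; the metric choice, column names, and file paths are assumptions, metric.compute is called with the (dataset, batch_size) signature used above, and BaseBenchmark.__init__ is assumed to take no required arguments, as its visible signature suggests.

from rageval.metrics import AnswerEMCorrectness
from benchmarks.base import BaseBenchmark


class ToyBenchmark(BaseBenchmark):
    name = "toy_benchmark"
    metrics = [AnswerEMCorrectness(ignore_case=True)]

    def load_data(self, **kwargs):
        # Reuse the base implementation, which stores the loaded Dataset on self.dataset.
        super().load_data(**kwargs)

    def _evaluate(self):
        results = {}
        ds = self.dataset
        for metric in self.metrics:
            # Copy the reference column into the column this metric expects
            # (both column names are assumptions, not taken from the commit).
            ds = self.prepare_data(input_column="short_answers", label_column="gt_answers")
            score, ds = metric.compute(ds, 1)
            results[metric.name] = score
        return results, ds


benchmark = ToyBenchmark()
results = benchmark.evaluate(path="json", data_files="answers.jsonl", split="train")
benchmark.save_dataset("scored_dataset.jsonl")
benchmark.save_results("results.json")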
