Add way to load local datafile #687

Status: Open · wants to merge 2 commits into main
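This PR adds an hf_data_files field to LightevalTaskConfig, stores it on LightevalTask as dataset_files, and forwards it through both dataset-loading paths to download_dataset_worker, which passes it to datasets.load_dataset as data_files. That lets a task evaluate on a local data file instead of a Hub dataset. Below is a minimal sketch of a task definition using the new field; the loader name, file path, and placeholder fields are illustrative, not taken from this PR:

    from lighteval.tasks.lighteval_task import LightevalTaskConfig

    # Hypothetical task config: hf_repo selects the generic "json" loader and
    # hf_data_files (new in this PR) points it at a local file.
    local_task = LightevalTaskConfig(
        name="my_local_task",
        prompt_function=...,  # placeholder: task-specific prompt builder
        hf_repo="json",  # generic loader instead of a Hub repo id
        hf_subset="default",
        hf_data_files="data/my_eval_set.json",
        hf_avail_splits=["train"],
        evaluation_splits=["train"],
        metric=[...],  # placeholder: task-specific metrics
        suite=["community"],
    )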
106 changes: 82 additions & 24 deletions src/lighteval/tasks/lighteval_task.py
@@ -56,7 +56,6 @@
 )
 from lighteval.utils.utils import ListLike, as_list, download_dataset_worker
-

 if TYPE_CHECKING:
     from lighteval.logging.evaluation_tracker import EvaluationTracker

@@ -97,7 +96,10 @@ class LightevalTaskConfig:
     # Additional hf dataset config
     hf_revision: Optional[str] = None
     hf_filter: Optional[Callable[[dict], bool]] = None
-    hf_avail_splits: Optional[ListLike[str]] = field(default_factory=lambda: ["train", "validation", "test"])
+    hf_avail_splits: Optional[ListLike[str]] = field(
+        default_factory=lambda: ["train", "validation", "test"]
+    )
+    hf_data_files: Optional[str] = None
     # We default to false, to reduce security issues
     trust_dataset: bool = False

@@ -123,14 +125,21 @@ class LightevalTaskConfig:

     def __post_init__(self):
         # If we got a Metrics enums instead of a Metric, we convert
-        self.metric = [metric.value if isinstance(metric, Metrics) else metric for metric in self.metric]
+        self.metric = [
+            metric.value if isinstance(metric, Metrics) else metric
+            for metric in self.metric
+        ]

         # Convert list to tuple for hashing
         self.metric = tuple(self.metric)
-        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+        self.hf_avail_splits = (
+            tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+        )
         self.evaluation_splits = tuple(self.evaluation_splits)
         self.suite = tuple(self.suite)
-        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else ()
+        self.stop_sequence = (
+            tuple(self.stop_sequence) if self.stop_sequence is not None else ()
+        )

     def print(self):
         md_writer = MarkdownTableWriter()
@@ -143,7 +152,9 @@ def print(self):
                 for ix, metrics in enumerate(v):
                     for metric_k, metric_v in metrics.items():
                         if inspect.ismethod(metric_v):
-                            values.append([f"{k} {ix}: {metric_k}", metric_v.__qualname__])
+                            values.append(
+                                [f"{k} {ix}: {metric_k}", metric_v.__qualname__]
+                            )
                         else:
                             values.append([f"{k} {ix}: {metric_k}", repr(metric_v)])

@@ -182,6 +193,7 @@ def __init__( # noqa: C901
         self.dataset_config_name = cfg.hf_subset
         self.dataset_revision = cfg.hf_revision
         self.dataset_filter = cfg.hf_filter
+        self.dataset_files = cfg.hf_data_files
         self.trust_dataset = cfg.trust_dataset
         self.dataset: Optional[DatasetDict] = None  # Delayed download
         logger.info(f"{self.dataset_path} {self.dataset_config_name}")
@@ -194,19 +206,29 @@ def __init__( # noqa: C901
         if cfg.few_shots_split is not None:
             self.fewshot_split = as_list(cfg.few_shots_split)
         else:
-            self.fewshot_split = self.get_first_possible_fewshot_splits(cfg.hf_avail_splits or [])
+            self.fewshot_split = self.get_first_possible_fewshot_splits(
+                cfg.hf_avail_splits or []
+            )
         self.fewshot_selection = cfg.few_shots_select

         # Metrics
         self.metrics = as_list(cfg.metric)
         self.suite = as_list(cfg.suite)
-        ignored = [metric for metric in self.metrics if metric.category == MetricCategory.IGNORED]
+        ignored = [
+            metric
+            for metric in self.metrics
+            if metric.category == MetricCategory.IGNORED
+        ]

         if len(ignored) > 0:
-            logger.warning(f"Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
+            logger.warning(
+                f"Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}."
+            )

         current_categories = [metric.category for metric in self.metrics]
-        self.has_metric_category = {category: (category in current_categories) for category in MetricCategory}
+        self.has_metric_category = {
+            category: (category in current_categories) for category in MetricCategory
+        }

         # We assume num_samples always contains 1 (for base generative evals)
         self.num_samples = [1]
@@ -244,20 +266,26 @@ def get_first_possible_fewshot_splits(
             list[str]: List of the first available fewshot splits.
         """
         # Possible few shot splits are the available splits not used for evaluation
-        possible_fewshot_splits = [k for k in available_splits if k not in self.evaluation_split]
+        possible_fewshot_splits = [
+            k for k in available_splits if k not in self.evaluation_split
+        ]
         stored_splits = []

         # We look at these keys in order (first the training sets, then the validation sets)
         allowed_splits = ["train", "dev", "valid", "default"]
         for allowed_split in allowed_splits:
             # We do a partial match of the allowed splits
-            available_splits = [k for k in possible_fewshot_splits if allowed_split in k]
+            available_splits = [
+                k for k in possible_fewshot_splits if allowed_split in k
+            ]
             stored_splits.extend(available_splits)

         if len(stored_splits) > 0:
             return stored_splits[:number_of_splits]

-        logger.warning(f"Careful, the task {self.name} is using evaluation data to build the few shot examples.")
+        logger.warning(
+            f"Careful, the task {self.name} is using evaluation data to build the few shot examples."
+        )
         return None

     def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
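For intuition, a worked sketch of the partial-match selection above (split names invented): with available splits ["train_r1", "validation", "test"] and evaluation on "test", "train" partially matches "train_r1" and "valid" matches "validation", so the method returns ["train_r1"] when number_of_splits=1.

    # Standalone illustration of the partial-match logic (split names invented):
    available = ["train_r1", "validation", "test"]
    candidates = [k for k in available if k not in ["test"]]  # evaluation split excluded
    ordered = [
        k for allowed in ["train", "dev", "valid", "default"] for k in candidates if allowed in k
    ]
    assert ordered == ["train_r1", "validation"]  # then truncated to number_of_splits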
@@ -279,6 +307,7 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
                 self.trust_dataset,
                 self.dataset_filter,
                 self.dataset_revision,
+                self.dataset_files,
             )
         splits = as_list(splits)

@@ -319,9 +348,13 @@ def fewshot_docs(self) -> list[Doc]:

         # If we have no available few shot split, the few shot data is the eval data!
         if self.fewshot_split is None:
-            self._fewshot_docs = self._get_docs_from_split(self.evaluation_split, few_shots=True)
+            self._fewshot_docs = self._get_docs_from_split(
+                self.evaluation_split, few_shots=True
+            )
         else:  # Normal case
-            self._fewshot_docs = self._get_docs_from_split(self.fewshot_split, few_shots=True)
+            self._fewshot_docs = self._get_docs_from_split(
+                self.fewshot_split, few_shots=True
+            )
         return self._fewshot_docs

     def eval_docs(self) -> list[Doc]:
@@ -338,7 +371,11 @@ def eval_docs(self) -> list[Doc]:
         return self._docs

     def construct_requests(
-        self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
+        self,
+        formatted_doc: Doc,
+        context: str,
+        document_id_seed: str,
+        current_task_name: str,
     ) -> Dict[RequestType, List[Request]]:
         """
         Constructs a list of requests from the task based on the given parameters.
@@ -435,7 +472,10 @@ def construct_requests(
                     choice=choice,
                     metric_categories=[
                         c
-                        for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]
+                        for c in [
+                            MetricCategory.MULTICHOICE,
+                            MetricCategory.MULTICHOICE_PMI,
+                        ]
                         if self.has_metric_category[c]
                     ],
                 )
@@ -499,15 +539,20 @@ def construct_requests(

     def get_metric_method_from_category(self, metric_category):
         if not self.has_metric_category[metric_category]:
-            raise ValueError(f"Requested a metric category {metric_category} absent from the task list.")
+            raise ValueError(
+                f"Requested a metric category {metric_category} absent from the task list."
+            )

         return LightevalTask._get_metric_method_from_category(metric_category)

     @staticmethod
     def _get_metric_method_from_category(metric_category):
         if metric_category == MetricCategory.TARGET_PERPLEXITY:
             return apply_target_perplexity_metric
-        if metric_category in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]:
+        if metric_category in [
+            MetricCategory.MULTICHOICE,
+            MetricCategory.MULTICHOICE_PMI,
+        ]:
             return apply_multichoice_metric
         if metric_category == MetricCategory.MULTICHOICE_ONE_TOKEN:
             return apply_multichoice_metric_one_token
@@ -519,7 +564,10 @@ def _get_metric_method_from_category(metric_category):
             MetricCategory.GENERATIVE_LOGPROB,
         ]:
             return apply_generative_metric
-        if metric_category in [MetricCategory.LLM_AS_JUDGE_MULTI_TURN, MetricCategory.LLM_AS_JUDGE]:
+        if metric_category in [
+            MetricCategory.LLM_AS_JUDGE_MULTI_TURN,
+            MetricCategory.LLM_AS_JUDGE,
+        ]:
             return apply_llm_as_judge_metric

     def aggregation(self):
@@ -530,7 +578,9 @@ def aggregation(self):
         return Metrics.corpus_level_fns(self.metrics)

     @staticmethod
-    def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int = 1) -> None:
+    def load_datasets(
+        tasks: list["LightevalTask"], dataset_loading_processes: int = 1
+    ) -> None:
         """
         Load datasets from the HuggingFace Hub for the given tasks.

@@ -550,6 +600,7 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
                     task.trust_dataset,
                     task.dataset_filter,
                     task.dataset_revision,
+                    task.dataset_files,
                 )
                 for task in tasks
             ]
@@ -564,6 +615,7 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
                             task.trust_dataset,
                             task.dataset_filter,
                             task.dataset_revision,
+                            task.dataset_files,
                         )
                         for task in tasks
                     ],
@@ -615,13 +667,17 @@ def create_requests_from_tasks( # noqa: C901
     requests: dict[RequestType, list[Request]] = collections.defaultdict(list)

     # Filter out tasks that don't have any docs
-    task_dict_items = [(name, task) for name, task in task_dict.items() if len(task.eval_docs()) > 0]
+    task_dict_items = [
+        (name, task) for name, task in task_dict.items() if len(task.eval_docs()) > 0
+    ]

     # Get lists of each type of request
     for task_name, task in task_dict_items:
         task_docs = list(task.eval_docs())
         n_samples = min(max_samples, len(task_docs)) if max_samples else len(task_docs)
-        evaluation_tracker.task_config_logger.log_num_docs(task_name, len(task_docs), n_samples)
+        evaluation_tracker.task_config_logger.log_num_docs(
+            task_name, len(task_docs), n_samples
+        )

         # logs out the different versions of the tasks for every few shot
         for num_fewshot, _ in fewshot_dict[task_name]:
@@ -655,7 +711,9 @@ def create_requests_from_tasks( # noqa: C901
                 # Constructing the requests
                 cur_task_name = f"{task_name}|{num_fewshot}"
                 docs[SampleUid(cur_task_name, doc_id_seed)] = doc
-                req_type_reqs_dict = task.construct_requests(doc, doc.ctx, doc_id_seed, cur_task_name)
+                req_type_reqs_dict = task.construct_requests(
+                    doc, doc.ctx, doc_id_seed, cur_task_name
+                )
                 for req_type, reqs in req_type_reqs_dict.items():
                     requests[req_type].extend(reqs)

28 changes: 22 additions & 6 deletions src/lighteval/utils/utils.py
@@ -23,7 +23,9 @@ def flatten_dict(nested: dict, sep="/") -> dict:
     """Flatten dictionary, list, tuple and concatenate nested keys with separator."""

     def clean_markdown(v: str) -> str:
-        return v.replace("|", "_").replace("\n", "_") if isinstance(v, str) else v  # Need this for markdown
+        return (
+            v.replace("|", "_").replace("\n", "_") if isinstance(v, str) else v
+        )  # Need this for markdown

     def rec(nest: dict, prefix: str, into: dict):
         for k, v in sorted(nest.items()):
@@ -37,9 +39,13 @@ def rec(nest: dict, prefix: str, into: dict):
                         rec(vv, prefix + k + sep + str(i) + sep, into)
                     else:
                         vv = (
-                            vv.replace("|", "_").replace("\n", "_") if isinstance(vv, str) else vv
+                            vv.replace("|", "_").replace("\n", "_")
+                            if isinstance(vv, str)
+                            else vv
                         )  # Need this for markdown
-                        into[prefix + k + sep + str(i)] = vv.tolist() if isinstance(vv, np.ndarray) else vv
+                        into[prefix + k + sep + str(i)] = (
+                            vv.tolist() if isinstance(vv, np.ndarray) else vv
+                        )
             elif isinstance(v, np.ndarray):
                 into[prefix + k + sep + str(i)] = v.tolist()
             else:
Expand All @@ -63,7 +69,9 @@ def clean_s3_links(value: str) -> str:
s3_bucket, s3_prefix = str(value).replace("s3://", "").split("/", maxsplit=1)
if not s3_prefix.endswith("/"):
s3_prefix += "/"
link_str = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}"
link_str = (
f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}"
)
value = f'<a href="{link_str}" target="_blank"> {value} </a>'
return value

@@ -151,7 +159,11 @@ def flatten(item: list[Union[list, str]]) -> list[str]:
     """
     flat_item = []
     for sub_item in item:
-        flat_item.extend(sub_item) if isinstance(sub_item, list) else flat_item.append(sub_item)
+        (
+            flat_item.extend(sub_item)
+            if isinstance(sub_item, list)
+            else flat_item.append(sub_item)
+        )
     return flat_item


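A quick usage note for flatten: it splices in exactly one level of nesting and leaves other items alone, e.g. (inputs invented):

    assert flatten(["a", ["b", "c"], "d"]) == ["a", "b", "c", "d"]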
@@ -205,6 +217,7 @@ def download_dataset_worker(
     trust_dataset: bool,
     dataset_filter: Callable[[dict], bool] | None = None,
     revision: str | None = None,
+    data_files: str | None = None,
 ) -> DatasetDict:
     """
     Worker function to download a dataset from the HuggingFace Hub.
@@ -218,6 +231,7 @@ def download_dataset_worker(
         download_mode=None,
         trust_remote_code=trust_dataset,
         revision=revision,
+        data_files=data_files,
     )

     if dataset_filter is not None:
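With the pass-through above, a local file reaches datasets.load_dataset unchanged. For reference, the standalone equivalent when a task points hf_repo at the generic "json" loader (file path invented):

    from datasets import load_dataset

    # The generic "json" loader reads a local file and exposes it as a
    # DatasetDict with a single "train" split by default.
    dataset = load_dataset("json", data_files="data/my_eval_set.json")
    print(dataset)  # DatasetDict({"train": Dataset({...})})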
@@ -227,5 +241,7 @@ def download_dataset_worker(
     return dataset  # type: ignore


-def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray:
+def safe_divide(
+    numerator: np.ndarray, denominator: float, default_value: float = 0.0
+) -> np.ndarray:
     return np.where(denominator != 0, numerator / denominator, default_value)
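A small usage sketch of safe_divide (values invented). Note that numerator / denominator is evaluated eagerly, before np.where selects a branch, so a zero denominator still triggers NumPy's divide-by-zero warning even though default_value is what gets returned:

    import numpy as np

    scores = np.array([2.0, 4.0, 6.0])
    print(safe_divide(scores, 2.0))  # [1. 2. 3.]
    print(safe_divide(scores, 0.0))  # [0. 0. 0.] (default_value), with a RuntimeWarning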