
Commit c6d1231

Adds More Generative tasks (#694)

Authored by hynky1999 (Hynek Kydlicek) and clefourrier (Clémentine Fourrier)

* add smolm generative tasks
* add jeopardy
* pretty 🥰
* consistent stop sequences
* add versions

Co-authored-by: Hynek Kydlicek <kydlicek.hynek@huggingface.co>
Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
1 parent f684d35 commit c6d1231

3 files changed: +130, -30 lines

src/lighteval/metrics/dynamic_metrics.py

Lines changed: 7 additions & 7 deletions
@@ -61,8 +61,8 @@ def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None)
    Creates an accuracy (loglikelihood) metric, which returns accuracy given normalization.
    """

-    normalization_str = normalization.name if normalization else ""
-    metric_name = f"acc_{normalization_str}"
+    normalization_str = f"_{normalization.name}" if normalization else ""
+    metric_name = f"acc{normalization_str}"
    return SampleLevelMetric(
        metric_name=metric_name,
        sample_level_fn=LoglikelihoodAcc(logprob_normalization=normalization).compute,
@@ -83,8 +83,8 @@ def normalized_multi_choice_prob_metric(
    Creates a normalized multi-choice probability metric, which returns the probability of the gold choice / sum of probabilities of all choices (after logprobs are normalized).
    """

-    normalization_str = normalization.name if normalization else ""
-    metric_name = "_".join(filter(None, ["normalized_mc_prob_", normalization_str]))
+    normalization_str = f"_{normalization.name}" if normalization else ""
+    metric_name = f"normalized_mc_prob{normalization_str}"

    return SampleLevelMetric(
        metric_name=metric_name,
@@ -108,8 +108,8 @@ def probability_metric(
    Creates a probability metric, which returns the probability of the gold choice given normalization.
    """

-    normalization_str = normalization.name if normalization else ""
-    metric_name = "_".join(filter(None, ["prob", normalization_str]))
+    normalization_str = f"_{normalization.name}" if normalization else ""
+    metric_name = f"prob{normalization_str}"

    return SampleLevelMetric(
        metric_name=metric_name,
@@ -188,7 +188,7 @@ def multilingual_quasi_exact_match_metric(
def multilingual_extractive_match_metric(
    language: Language = Language.ENGLISH,
    gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
-    pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
+    pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function: Callable[[list[float]], float] = max,
    fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
    extraction_mode: Literal["first_match", "any_match"] = "any_match",
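Side note on the metric-name change above: without a normalization, the old construction left a dangling underscore (e.g. "acc_"), while the new one only adds the underscore together with the normalization name. A minimal sketch of the two constructions in plain Python (no lighteval imports needed):

    normalization = None  # no normalization requested

    # old construction: dangling underscore when normalization is None
    old_str = normalization.name if normalization else ""
    old_name = f"acc_{old_str}"  # -> "acc_"

    # new construction: underscore only appears together with the normalization name
    new_str = f"_{normalization.name}" if normalization else ""
    new_name = f"acc{new_str}"  # -> "acc"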

src/lighteval/tasks/default_prompts.py

Lines changed: 7 additions & 0 deletions
@@ -2774,3 +2774,10 @@ def xsum(line, task_name: str = None):
        choices=[str(line["summary"])],
        specific={"text": line["article"]},
    )
+
+
+# Utility for drop task
+def get_drop_date(x):
+    components = [x["day"], x["month"], x["year"]]
+    components = list(filter(lambda x: x, components))
+    return " ".join(components)
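For illustration, a hedged usage sketch of the new helper (the date values are hypothetical; empty components are dropped and the remaining ones are joined with spaces):

    import lighteval.tasks.default_prompts as prompt

    prompt.get_drop_date({"day": "6", "month": "June", "year": "1944"})  # -> "6 June 1944"
    prompt.get_drop_date({"day": "", "month": "", "year": "1863"})       # -> "1863"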

src/lighteval/tasks/default_tasks.py

Lines changed: 116 additions & 23 deletions
@@ -22,6 +22,8 @@
import lighteval.tasks.default_prompts as prompt
from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.templates.qa import get_qa_prompt_function
+from lighteval.utils.language import Language


abstract_narrative_understanding_bigbench = LightevalTaskConfig(
@@ -6627,21 +6629,28 @@
    trust_dataset=True,
    version=0,
)
-coqa_lighteval = LightevalTaskConfig(
+coqa_first_question = LightevalTaskConfig(
    name="coqa",
-    suite=["lighteval"],
-    prompt_function=prompt.coqa,
-    hf_repo="coqa",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["questions"][0],
+            "context": line["story"],
+            "choices": [line["answers"]["input_text"][0]],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="stanfordnlp/coqa",
    hf_subset="default",
    hf_avail_splits=["train", "validation"],
    evaluation_splits=["validation"],
-    few_shots_split=None,
-    few_shots_select=None,
-    generation_size=10,
-    metric=[Metrics.perfect_exact_match, Metrics.f1_score],
-    stop_sequence=["\n"],
-    trust_dataset=True,
-    version=0,
+    stop_sequence=["\n", "Question:", "question:"],
+    generation_size=100,
+    version=1,
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
)
coqa_bb_lighteval = LightevalTaskConfig(
    name="coqa_bb",
@@ -6835,21 +6844,43 @@
    trust_dataset=True,
    version=0,
)
-drop_lighteval = LightevalTaskConfig(
+drop_qa = LightevalTaskConfig(
    name="drop",
-    suite=["lighteval"],
-    prompt_function=prompt.drop,
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "context": line["passage"],
+            "question": line["question"],
+            "choices": list(
+                filter(
+                    lambda x: x,
+                    [line["answer"].get("number")]
+                    + line["answer"]["spans"]
+                    + [prompt.get_drop_date(line["answer"].get("date"))],
+                )
+            ),
+        },
+    ),
+    suite=("lighteval",),
    hf_repo="lighteval/drop_harness",
    hf_subset="default",
-    hf_avail_splits=["train", "validation"],
-    evaluation_splits=["validation"],
+    hf_filter=lambda line: list(
+        filter(
+            lambda x: x,
+            [line["answer"].get("number")]
+            + line["answer"]["spans"]
+            + [prompt.get_drop_date(line["answer"].get("date"))],
+        )
+    ),
+    evaluation_splits=("validation",),
    few_shots_split="train",
-    few_shots_select="random_sampling_from_train",
-    generation_size=None,
-    metric=[Metrics.drop],
-    stop_sequence=["."],
-    trust_dataset=True,
-    version=0,
+    generation_size=250,
+    stop_sequence=["Question:", "question:", "\n"],
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+    version=1,
)
dyck_language_2_helm = LightevalTaskConfig(
    name="dyck_language:2",
@@ -8581,6 +8612,27 @@
    trust_dataset=True,
    version=0,
)
+jeopardy = LightevalTaskConfig(
+    name="jeopardy",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["question"],
+            "choices": [line["answer"]],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="openaccess-ai-collective/jeopardy",
+    hf_subset="default",
+    evaluation_splits=("train",),
+    few_shots_split="train",
+    generation_size=250,
+    stop_sequence=["\n", "Question:", "question:"],
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
kanji_ascii_bigbench = LightevalTaskConfig(
    name="kanji_ascii",
    suite=["bigbench", "bigbench_json"],
@@ -13665,6 +13717,24 @@
    trust_dataset=True,
    version=0,
)
+natural_questions = LightevalTaskConfig(
+    name="natural_questions",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {"question": line["question"], "choices": [line["answer"]]},
+    ),
+    suite=("lighteval",),
+    hf_repo="lighteval/small_natural_questions",
+    hf_subset="default",
+    evaluation_splits=("test",),
+    few_shots_split="few_shot",
+    generation_size=250,
+    stop_sequence=["\n", "Question:", "question:"],
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
navigate_bigbench = LightevalTaskConfig(
    name="navigate",
    suite=["bigbench", "bigbench_json"],
@@ -14885,7 +14955,7 @@
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
-    few_shots_split=None,
+    few_shots_split="few_shot",
    few_shots_select=None,
    generation_size=2048,
    metric=[Metrics.simpleqa_judge],
@@ -15074,6 +15144,29 @@
    trust_dataset=True,
    version=0,
)
+squad_v2 = LightevalTaskConfig(
+    name="squad_v2",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["question"],
+            "context": line["context"],
+            "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="rajpurkar/squad_v2",
+    hf_subset="squad_v2",
+    hf_filter=lambda line: any(ans for ans in line["answers"]["text"] if len(ans) > 0),
+    evaluation_splits=("validation",),
+    few_shots_split="train",
+    stop_sequence=["\n", "Question:", "question:"],
+    generation_size=200,
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
storycloze_2016_lighteval = LightevalTaskConfig(
    name="storycloze:2016",
    suite=["lighteval", "storycloze"],
