Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/cookbook/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,65 @@
],
}

# Category names used to build the `deepmind_math_<category>` eval tasks.
# Every name has the form "<area>__<module>"; the flattened list keeps the
# areas, and the modules inside each area, in exactly the order written below.
DEEPMIND_MATH_CATEGORIES = [
    f"{area}__{module}"
    for area, modules in (
        (
            "algebra",
            (
                "linear_1d",
                "linear_1d_composed",
                "linear_2d",
                "linear_2d_composed",
                "polynomial_roots",
                "polynomial_roots_composed",
                "sequence_next_term",
                "sequence_nth_term",
            ),
        ),
        (
            "arithmetic",
            (
                "add_or_sub",
                "add_or_sub_in_base",
                "add_sub_multiple",
                "div",
                "mixed",
                "mul",
                "mul_div_multiple",
                "nearest_integer_root",
                "simplify_surd",
            ),
        ),
        (
            "calculus",
            (
                "differentiate",
                "differentiate_composed",
            ),
        ),
        (
            "comparison",
            (
                "closest",
                "closest_composed",
                "kth_biggest",
                "kth_biggest_composed",
                "pair",
                "pair_composed",
                "sort",
                "sort_composed",
            ),
        ),
        (
            "measurement",
            (
                "conversion",
                "time",
            ),
        ),
        (
            "numbers",
            (
                "base_conversion",
                "div_remainder",
                "div_remainder_composed",
                "gcd",
                "gcd_composed",
                "is_factor",
                "is_factor_composed",
                "is_prime",
                "is_prime_composed",
                "lcm",
                "lcm_composed",
                "list_prime_factors",
                "list_prime_factors_composed",
                "place_value",
                "place_value_composed",
                "round_number",
                "round_number_composed",
            ),
        ),
        (
            "polynomials",
            (
                "add",
                "coefficient_named",
                "collect",
                "compose",
                "evaluate",
                "evaluate_composed",
                "expand",
                "simplify_power",
            ),
        ),
        (
            "probability",
            (
                "swr_p_level_set",
                "swr_p_sequence",
            ),
        ),
    )
    for module in modules
]


# SSH-style git remote for the allenai/oe-eval-internal repository.
OE_EVAL_GIT_URL = "git@github.com:allenai/oe-eval-internal.git"
# No commit hash is pinned here.
# NOTE(review): presumably consumers treat None as "use the default ref" — confirm against callers.
OE_EVAL_COMMIT_HASH = None
Expand Down
125 changes: 125 additions & 0 deletions src/cookbook/eval/named_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,11 @@ class MinervaMidtrainReasoningGroup(BaseAverageNamedTasksGroup):
tasks = [f"{subtask}::olmo3:midtrain" for subtask in constants.ALL_MINERVA_TASKS]


@NamedTasksGroupRegistry.register("deepmind_math::olmo3:heldout")
class DeepmindMathHeldoutGroup(BaseAverageNamedTasksGroup):
    """Task group registered as ``deepmind_math::olmo3:heldout``.

    Holds one ``deepmind_math_<category>::olmo3:heldout`` task for every entry
    of ``constants.DEEPMIND_MATH_CATEGORIES``, in that list's order.
    """

    tasks = [
        f"deepmind_math_{category}::olmo3:heldout"
        for category in constants.DEEPMIND_MATH_CATEGORIES
    ]


@NamedTasksGroupRegistry.register("math")
class MathGroup(BaseAverageOfAveragesNamedTasksGroup):
tasks = [
Expand Down Expand Up @@ -475,6 +480,11 @@ class BBHMidtrainThinkerGroup(BaseAverageNamedTasksGroup):
tasks = [f"bbh_{category}:cot::olmo3:midtrain" for category in constants.BBH_TASKS]


@NamedTasksGroupRegistry.register("bbh:cot::olmo3:heldout")
class BBHHeldoutGroup(BaseAverageNamedTasksGroup):
    """Task group registered as ``bbh:cot::olmo3:heldout``.

    Holds one ``bbh_<subtask>:cot::olmo3:heldout`` task for every entry of
    ``constants.BBH_TASKS``, in that list's order.
    """

    tasks = [
        f"bbh_{subtask}:cot::olmo3:heldout"
        for subtask in constants.BBH_TASKS
    ]


@NamedTasksGroupRegistry.register("ifeval_mt::tulu-thinker")
class IFEvalMTThinkerGroup(BaseAverageNamedTasksGroup):
    """Task group registered as ``ifeval_mt::tulu-thinker``.

    Holds one ``ifeval_mt_<variant>::tulu-thinker`` task for every entry of
    ``constants.IFEVAL_MT_TASKS``, in that list's order.
    """

    tasks = [
        f"ifeval_mt_{variant}::tulu-thinker"
        for variant in constants.IFEVAL_MT_TASKS
    ]
Expand Down Expand Up @@ -615,6 +625,37 @@ class Olmo3Dev1bQaBpbGroup(BaseAverageOfAveragesNamedTasksGroup):
"sciriff_yesno:bpb::olmes",
]

@NamedTasksGroupRegistry.register("olmo3:dev:1b:qa:bpb:v2")
class Olmo3Dev1bQaBpbV2Group(BaseAverageOfAveragesNamedTasksGroup):
    """Task group registered as ``olmo3:dev:1b:qa:bpb:v2``.

    Mixes nested groups with plain ``:bpb`` task names. Runs of names that
    share a suffix are generated inline; the resulting list has the same
    elements, in the same order, as a fully spelled-out literal.
    """

    tasks = [
        # Core OLMES
        ARCBPBFullGroup(),
        MMLUBpbGroup(),
        *(
            f"{dataset}:bpb::olmes:full"
            for dataset in ("csqa", "hellaswag", "winogrande", "socialiqa", "piqa")
        ),
        # Gen OLMES
        *(
            f"{dataset}:bpb::gen2mc:xlarge"
            for dataset in ("coqa", "drop", "jeopardy", "naturalqs", "squad")
        ),
        # New OLMo 3
        "sciq:bpb::olmo3",
        "qasper_yesno:bpb::olmes",
        BasicBpbGroup(),
        *(
            f"{dataset}:bpb"
            for dataset in ("lab_bench_dbqa", "lab_bench_protocolqa", "lambada")
        ),
        *(f"{dataset}:bpb::none" for dataset in ("medmcqa", "medqa_en")),
        "sciriff_yesno:bpb::olmes",
    ]


@NamedTasksGroupRegistry.register("olmo3:dev:1b:qa:rc")
class Olmo3Dev1bQaRcGroup(BaseAverageOfAveragesNamedTasksGroup):
Expand Down Expand Up @@ -648,6 +689,38 @@ class Olmo3Dev1bQaRcGroup(BaseAverageOfAveragesNamedTasksGroup):
]


@NamedTasksGroupRegistry.register("olmo3:dev:1b:qa:rc:v2")
class Olmo3Dev1bQaRcV2Group(BaseAverageOfAveragesNamedTasksGroup):
    """Task group registered as ``olmo3:dev:1b:qa:rc:v2``.

    Mixes nested groups with plain ``:rc`` task names. Runs of names that
    share a suffix are generated inline; the resulting list has the same
    elements, in the same order, as a fully spelled-out literal.
    """

    tasks = [
        # Core OLMES
        ARCRCFullGroup(),
        MMLURCGroup(),
        *(
            f"{dataset}:rc::olmes:full"
            for dataset in ("csqa", "hellaswag", "winogrande", "socialiqa", "piqa")
        ),
        # Gen OLMES
        *(
            f"{dataset}:rc::gen2mc:xlarge"
            for dataset in ("coqa", "drop", "jeopardy", "naturalqs", "squad")
        ),
        # New OLMo 3
        "sciq:rc::olmo3",
        "qasper_yesno:rc::olmes",
        BasicRCGroup(),
        # NOTE(review): the next three names carry no ":rc" suffix, unlike the
        # ":bpb"-suffixed entries in the bpb:v2 group — confirm this is intended.
        "lab_bench_dbqa",
        "lab_bench_protocolqa",
        "lambada",
        *(f"{dataset}:rc::none" for dataset in ("medmcqa", "medqa_en")),
        "sciriff_yesno:rc::olmes",
    ]


@NamedTasksGroupRegistry.register("olmo3:dev:1b:bpb")
class Olmo3Dev1bBpbGroup(BaseAverageOfAveragesNamedTasksGroup):
tasks = [
Expand Down Expand Up @@ -749,6 +822,23 @@ class Olmo3Dev7bMcqaNonSTEMGroup(BaseAverageOfAveragesNamedTasksGroup):
]


@NamedTasksGroupRegistry.register("olmo3:dev:7b:mcqa:non_stem:v2")
class Olmo3Dev7bMcqaNonSTEMV2Group(BaseAverageOfAveragesNamedTasksGroup):
    """Task group registered as ``olmo3:dev:7b:mcqa:non_stem:v2``.

    Combines three MMLU multiple-choice sub-groups with ``:mc`` task names.
    Runs of names that share a suffix are generated inline; the resulting
    list has the same elements, in the same order, as a spelled-out literal.
    """

    tasks = [
        # Nested MMLU multiple-choice groups
        MMLUHumanitiesMCGroup(),
        MMLUSocialSciencesMCGroup(),
        MMLUOtherMCGroup(),
        # Plain multiple-choice tasks
        *(f"{dataset}:mc::xlarge" for dataset in ("csqa", "piqa", "socialiqa")),
        # gen2mc tasks
        *(
            f"{dataset}:mc::gen2mc:xlarge"
            for dataset in ("coqa", "drop", "jeopardy", "naturalqs", "squad")
        ),
    ]


# # # # # # # # # # # # # # # DISPLAY TASK GROUPS # # # # # # # # # # # # # # # # #
# These are just shortcuts to display many metrics at once. No need to average. #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
Expand Down Expand Up @@ -899,4 +989,39 @@ class Olmo3DevMidtrainV2MainGroup(BaseNamedTasksWithNoAverageGroup):
"popqa::olmo3:midtrain",
AgiEvalEnglishMidtrainGroup(),
MMLUMidtrainGroup(),
]


@NamedTasksGroupRegistry.register("olmo3:base_heldout")
class Olmo3BaseHeldoutGroup(BaseNamedTasksWithNoAverageGroup):
    """Display group registered as ``olmo3:base_heldout``.

    Bundles three nested groups and one standalone task so they can be shown
    together.
    """

    tasks = [
        # Nested groups
        BBHHeldoutGroup(),
        MMLUProMCGroup(),
        DeepmindMathHeldoutGroup(),
        # Standalone task
        "lbpp::olmo3",
    ]


@NamedTasksGroupRegistry.register("olmo3:paper")
class Olmo3PaperGroup(BaseNamedTasksWithNoAverageGroup):
    """Display group registered as ``olmo3:paper``.

    Collects the nested groups listed below, organised into four suites
    (base_easy, base, base_chat, heldout), so they can be shown together.
    """

    tasks = [
        # Suite: olmo3:base_easy
        Olmo3Dev1bMathBpbGroup(),
        Olmo3Dev1bCodeBpbGroup(),
        Olmo3Dev1bQaBpbV2Group(),
        Olmo3Dev1bQaRcV2Group(),
        # Suite: olmo3:base
        Olmo3Dev7bMcqaSTEMGroup(),
        Olmo3Dev7bMcqaNonSTEMV2Group(),
        Olmo3Dev7bGenGroup(),
        Olmo3Dev7bMathV2Group(),
        Olmo3Dev7bCodeGenV2Group(),
        Olmo3Dev7bCodeFimGroup(),
        # Suite: olmo3:base_chat
        Olmo3DevMidtrainV2MainGroup(),
        # Suite: olmo3:heldout
        Olmo3BaseHeldoutGroup(),
    ]