Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tasks to replicate Math-shepherd #1052

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open

Add tasks to replicate Math-shepherd #1052

wants to merge 19 commits into from

Conversation

plaguss
Copy link
Contributor

@plaguss plaguss commented Nov 6, 2024

Description

WORK IN PROGRESS

This task Integrates the tasks to replicate:
Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations

Example pipeline:

from datasets import load_dataset

from distilabel.steps.tasks.math_shepherd.generator import MathShepherdGenerator
from distilabel.steps.tasks.math_shepherd.completer import MathShepherdCompleter
from distilabel.steps.tasks.math_shepherd.utils import FormatPRM
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, ExpandColumns

ds_name = "openai/gsm8k"

ds = load_dataset(ds_name, "main", split="test").rename_column("question", "instruction").select(range(3))


with Pipeline(name="Math-Shepherd") as pipe:
    model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    llm_70B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 1024, "temperature": 0.6},
    )
    llm_8B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 2048, "temperature": 0.6},
    )

    generator_golden = MathShepherdGenerator(
        name="golden_generator",
        llm=llm_70B,
    )
    generator = MathShepherdGenerator(
        name="generator",
        llm=llm_8B,
        M=5  # Generate 5 sample solutions
    )
    completer = MathShepherdCompleter(
        name="completer",
        llm=llm_8B,
        N=4  # Each solution will be tested with 4 completions during labelling
    )

    combine = CombineOutputs()
    expand = ExpandColumns(
        name="expand_columns",
        columns=["solutions"],
        encoded=True,
    )
    formatter = FormatPRM(name="format_prm")
    [generator_golden, generator] >> combine >> completer >> expand >> formatter


if __name__ == "__main__":
    distiset = pipe.run(use_cache=False, dataset=ds)
    distiset.push_to_hub("plaguss/test_math_shepherd_prm")

A sample dataset can be seen at plaguss/test_math_shepherd_prm

@plaguss plaguss added the enhancement New feature or request label Nov 6, 2024
@plaguss plaguss added this to the 1.5.0 milestone Nov 6, 2024
@plaguss plaguss self-assigned this Nov 6, 2024
Copy link

github-actions bot commented Nov 6, 2024

Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1052/

Copy link

codspeed-hq bot commented Nov 6, 2024

CodSpeed Performance Report

Merging #1052 will not alter performance

Comparing math-shepherd (3ca7b7d) with develop (e830e25)

Summary

✅ 1 untouched benchmarks

@plaguss plaguss marked this pull request as ready for review November 12, 2024 12:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant