
[Model] Support math-shepherd-mistral-7b-prm model #9697

Merged

Conversation

@Went-Liang (Contributor) commented Oct 25, 2024


Support peiyi9979/math-shepherd-mistral-7b-prm as an embedding model.

As mentioned in #9314, a Process-Supervised Reward Model (PRM), which provides reward scores for the intermediate steps generated by an LLM, can offer more fine-grained optimization for reinforcement learning (RL). This will help the community reproduce the OpenAI o1 model. PR #9424 allows any model that adds a pooler method to be used as an embedding model.

Therefore, this PR adds a pooler method to LlamaForCausalLM, introduces a pooling type named "STEP", and adds a PoolerConfig class to make it easier for users to configure the pooler. In STEP mode, the peiyi9979/math-shepherd-mistral-7b-prm model can be used by setting the pooling-step-tag-id and pooling-returned-token-ids options. pooling-returned-token-ids is a list of vocabulary indices to extract, such as the token IDs of good_token and bad_token in the math-shepherd-mistral-7b-prm model. When pooling-step-tag-id is not None, only the scores at the positions of pooling-step-tag-id in the generated solution are returned; otherwise, scores for all tokens are returned.
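
As a rough illustration of the STEP pooling described above, the sketch below shows, for a single prompt, how per-token logits could be reduced to good/bad probabilities at the step-tag positions. This is illustrative only, not the vLLM implementation; the function name and tensor shapes are assumptions made for the example:

import torch

def step_pool_sketch(logits, token_ids, returned_token_ids, step_tag_id=None):
    """Illustrative STEP pooling for one prompt (not vLLM's actual code).

    logits:    [seq_len, vocab_size] per-token logits from the model
    token_ids: [seq_len] input token ids of the same prompt
    """
    # Keep only the requested vocabulary entries (e.g. the good/bad token
    # ids 648 and 387 of math-shepherd-mistral-7b-prm) and normalize them.
    scores = logits[:, returned_token_ids].softmax(dim=-1)
    if step_tag_id is not None:
        # Return scores only at the positions of the step tag token.
        scores = scores[token_ids == step_tag_id]
    return scores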

The model can be served with:

python -m vllm.entrypoints.openai.api_server \
    --model peiyi9979/math-shepherd-mistral-7b-prm \
    --trust-remote-code \
    --served-model-name math-shepherd-mistral-7b-prm \
    --port 8080 \
    --tensor-parallel-size 1 \
    --enforce-eager \
    --task embedding \
    --pooling-type STEP \
    --pooling-step-tag-id 12902 \
    --pooling-returned-token-ids 648 387

A test corresponding to the example on the Hugging Face model page is:

import torch
from openai import OpenAI

question = """Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"""
output1 = """Step 1: Janet's ducks lay 16 eggs per day. ки\nStep 2: She eats three for breakfast every morning, so she has 16 - 3 = 13 eggs left. ки\nStep 3: She bakes muffins for her friends every day with four eggs, so she has 13 - 4 = 9 eggs left. ки\nStep 4: She sells the remainder at the farmers' market daily for $2 per fresh duck egg, so she makes 9 * $2 = $18 every day at the farmers' market. The answer is: 18 ки""" # 18 is right
output2 = """Step 1: Janet's ducks lay 16 eggs per day. ки\nStep 2: She eats three for breakfast every morning, so she has 16 - 3 = 13 eggs left. ки\nStep 3: She bakes muffins for her friends every day with four eggs, so she has 13 - 4 = 9 eggs left. ки\nStep 4: She sells the remainder at the farmers' market daily for $2 per fresh duck egg, so she makes 9 * $2 = $17 every day at the farmers' market. The answer is: 17 ки""" # 17 is wrong
full_prompt = [f"{question} {output}" for output in [output1, output2]]


# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8080/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

responses = client.embeddings.create(
    input=full_prompt,
    model=model,
)

for data in responses.data:
    # Column 0 is the probability of good_token ('+') at each step tag position.
    print(torch.tensor(data.embedding).view(-1, 2)[:, 0])
# tensor([0.9956, 0.9956, 0.9985, 0.9956])
# tensor([0.9956, 0.9956, 0.9985, 0.0240])

Of course, you can also use it directly like this:

from vllm import LLM

llm = LLM(
    model='peiyi9979/math-shepherd-mistral-7b-prm', 
    tensor_parallel_size=1, 
    task="embedding",
    pooling_type="STEP",
    pooling_step_tag_id=12902,
    pooling_returned_token_ids=[648, 387],
)

question = """Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"""
output1 = """Step 1: Janet's ducks lay 16 eggs per day. ки\nStep 2: She eats three for breakfast every morning, so she has 16 - 3 = 13 eggs left. ки\nStep 3: She bakes muffins for her friends every day with four eggs, so she has 13 - 4 = 9 eggs left. ки\nStep 4: She sells the remainder at the farmers' market daily for $2 per fresh duck egg, so she makes 9 * $2 = $18 every day at the farmers' market. The answer is: 18 ки""" # 18 is right
output2 = """Step 1: Janet's ducks lay 16 eggs per day. ки\nStep 2: She eats three for breakfast every morning, so she has 16 - 3 = 13 eggs left. ки\nStep 3: She bakes muffins for her friends every day with four eggs, so she has 13 - 4 = 9 eggs left. ки\nStep 4: She sells the remainder at the farmers' market daily for $2 per fresh duck egg, so she makes 9 * $2 = $17 every day at the farmers' market. The answer is: 17 ки""" # 17 is wrong

good_token = '+'
bad_token = '-'
step_tag = 'ки'

tokenizer = llm.get_tokenizer()
[good_token_id, bad_token_id, step_tag_id] = tokenizer.encode(f"{good_token} {bad_token} {step_tag}")[1:] # [648, 387, 12902]


full_prompt = [f"{question} {output}" for output in [output1, output2]]
outputs = llm.encode(full_prompt)
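
The pooled scores can then be read back from outputs. The loop below is a sketch that assumes each result exposes the pooled values as a flat list via output.outputs.embedding, mirroring the embedding response returned by the server in the example above:

import torch

for output in outputs:
    # Each row is [good_token, bad_token]; column 0 is the probability of
    # '+' (the good token) at one step tag position.
    print(torch.tensor(output.outputs.embedding).view(-1, 2)[:, 0])
# Expected to match the tensors printed in the server-based example above.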

Thank you for your time reviewing this PR :)


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@Went-Liang force-pushed the feature/support_process_reward_model branch from 493b4b2 to d3f0ead on October 25, 2024 13:40
vllm/config.py — review thread (outdated, resolved)
vllm/engine/arg_utils.py — review thread (outdated, resolved)
@Went-Liang force-pushed the feature/support_process_reward_model branch from 3e2c7f4 to e62f65c on October 29, 2024 05:13
Comment on lines 394 to 401
self._pooler = Pooler(
    pooling_type=PoolingType[pooler_config.pooling_type]
    if pooler_config.pooling_type is not None else PoolingType.CLS,
    normalize=pooler_config.pooling_norm or True,
    softmax=pooler_config.pooling_softmax or False,
    step_tag_id=pooler_config.pooling_step_tag_id,
    returned_token_ids=pooler_config.pooling_returned_token_ids,
)
@DarkLight1337 (Member) commented Oct 29, 2024

Can we add a factory method to Pooler to automatically merge the config with model specific defaults?

@DarkLight1337 (Member) commented Oct 29, 2024

e.g. we should be able to write

self._pooler = Pooler.from_config_with_defaults(
    pooler_config,
    # These values are overridden if they are set inside the config
    pooling_type=PoolingType.CLS,
    normalize=True,
)
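
One possible shape for such a factory is sketched below. It is hypothetical (the merged code may differ) and assumes it lives on the Pooler class used above, with any value set in the config taking precedence over the model-specific defaults passed by the caller:

@classmethod
def from_config_with_defaults(cls, pooler_config, *, pooling_type, normalize,
                              softmax=False, step_tag_id=None,
                              returned_token_ids=None):
    # Config values, when present, override the model-specific defaults.
    return cls(
        pooling_type=PoolingType[pooler_config.pooling_type]
        if pooler_config.pooling_type is not None else pooling_type,
        normalize=pooler_config.pooling_norm
        if pooler_config.pooling_norm is not None else normalize,
        softmax=pooler_config.pooling_softmax
        if pooler_config.pooling_softmax is not None else softmax,
        step_tag_id=pooler_config.pooling_step_tag_id
        if pooler_config.pooling_step_tag_id is not None else step_tag_id,
        returned_token_ids=pooler_config.pooling_returned_token_ids
        if pooler_config.pooling_returned_token_ids is not None
        else returned_token_ids,
    )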

@Went-Liang (Contributor, Author) replied:

Please check it

@DarkLight1337 (Member) left a comment:

Looks good now, thanks for your effort and patience!

@DarkLight1337 (Member) commented Oct 29, 2024

Now you just have to get the tests to pass.

@Went-Liang force-pushed the feature/support_process_reward_model branch from db2552b to 4e468a3 on October 29, 2024 14:53
@DarkLight1337 enabled auto-merge (squash) October 29, 2024 15:46
github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 29, 2024
auto-merge was automatically disabled October 29, 2024 16:22

Head branch was pushed to by a user without write access

mergify bot commented Oct 29, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @Went-Liang please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 29, 2024
@Went-Liang force-pushed the feature/support_process_reward_model branch from f5434e1 to d1b0f5b on October 30, 2024 03:53
@mergify mergify bot removed the needs-rebase label Oct 30, 2024
@Went-Liang (Contributor, Author) commented Oct 30, 2024

Now you just have to get the tests to pass.

Excuse me, the test produced the following error (as shown in the image). This doesn't seem to be caused by my code changes. Could you please advise on how to handle this? @DarkLight1337

[screenshot of the failing CI test]

@DarkLight1337 (Member) commented:

I have retried the failing test, see if it passes this time.

@DarkLight1337 (Member) commented:

Looks like this issue comes from main branch, I have asked those with permissions to force-merge this.

@Went-Liang (Contributor, Author) replied:

Looks like this issue comes from main branch, I have asked those with permissions to force-merge this.

Thanks so much !!!

@simon-mo simon-mo merged commit 81f09cf into vllm-project:main Oct 30, 2024
61 of 63 checks passed
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Labels: frontend, ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: None yet
Participants: 3