add reward model api #3665

Merged
merged 5 commits into from Jun 30, 2025
56 changes: 55 additions & 1 deletion lmdeploy/serve/openai/api_server.py
@@ -32,7 +32,8 @@
CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
GenerateRequest, GenerateResponse, LogProbs, ModelCard, ModelList,
ModelPermission, TopLogprob, UpdateParamsRequest, UsageInfo)
ModelPermission, PoolingRequest, PoolingResponse, TopLogprob,
UpdateParamsRequest, UsageInfo)
from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
@@ -871,6 +872,59 @@ def encode(prompt: str, do_preprocess: bool, add_bos: bool):
        return EncodeResponse(input_ids=encoded, length=length)


@router.post('/pooling')
async def pooling(request: PoolingRequest, raw_request: Request = None):
    """Pooling prompts for a reward model.

    Per the vLLM documentation, https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#pooling-api_1,
    the Pooling API accepts the same input format as the Embeddings API.

    Go to https://platform.openai.com/docs/api-reference/embeddings/create
    for the Embeddings API specification.

    The request should be a JSON object with the following fields:
    - model (str): model name. Available from /v1/models.
    - input (List[int] | List[List[int]] | str | List[str]): input text to be embedded.
    """

    async_engine = VariableInterface.async_engine

    request_input = request.input
    model_name = request.model or async_engine.model_name

    # Normalize all inputs to be a batch (List[List[int]])
    if isinstance(request_input, str):
        input_ids = [async_engine.tokenizer.encode(request_input, add_special_tokens=False)]
    elif isinstance(request_input, List):
        if not request_input:
            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input list cannot be empty.')
        if isinstance(request_input[0], str):  # List[str]
            input_ids = [async_engine.tokenizer.encode(p, add_special_tokens=False) for p in request_input]
        elif isinstance(request_input[0], int):  # List[int]
            input_ids = [request_input]
        elif isinstance(request_input[0], List):  # List[List[int]]
            input_ids = request_input
        else:
            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input list contains an invalid type.')
    else:
        return create_error_response(HTTPStatus.BAD_REQUEST, 'Input must be a string or a list.')

    batch_scores = await async_engine._async_get_reward_score(input_ids)
    prompt_tokens = sum(len(ids) for ids in input_ids)
    usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens)

    data = []
    for i, score in enumerate(batch_scores):
        data.append({
            'index': i,
            'object': 'pooling',
            'data': score,
        })

    response = PoolingResponse(model=model_name, data=data, usage=usage)
    return response.model_dump()


@router.post('/update_weights', dependencies=[Depends(check_api_key)])
def update_params(request: UpdateParamsRequest, raw_request: Request = None):
"""Update weights for the model."""
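For anyone trying the new route out, here is a minimal client sketch (not part of this diff). The server address, port, and model name are placeholders and should match your running api_server; the response layout mirrors what the handler above builds.

# Sketch: calling the new /pooling route with the requests library.
# Assumes an lmdeploy api_server serving a supported reward model at this address.
import requests

resp = requests.post(
    'http://localhost:23333/pooling',
    json={
        'model': 'my-reward-model',  # placeholder; list real names via /v1/models
        'input': ['How do I sort a list in Python? Use sorted(my_list).'],
    },
)
resp.raise_for_status()
body = resp.json()
for item in body['data']:
    # each item holds the scalar reward score under 'data'
    print(item['index'], item['data'])
print(body['usage'])  # completion_tokens is always 0 for pooling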
27 changes: 27 additions & 0 deletions lmdeploy/serve/openai/protocol.py
@@ -331,6 +331,33 @@ class EmbeddingsResponse(BaseModel):
    usage: UsageInfo


class PoolingRequest(BaseModel):
    """Pooling request.

    Currently we follow the vLLM API protocol,
    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L1174

    Note that ideally we should reuse the input format of the embedding API:
    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L1174
    https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L383
    """
    model: Optional[str] = None
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal['float', 'base64'] = 'float'
    dimensions: Optional[int] = None
    user: Optional[str] = None


class PoolingResponse(BaseModel):
    """Pooling response."""
    id: str = Field(default_factory=lambda: f'pool-{shortuuid.random()}')
    object: str = 'list'
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = None
    data: List[Dict[str, Any]]
    usage: UsageInfo


class EncodeRequest(BaseModel):
    """Encode request."""
    input: Union[str, List[str]]
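As a quick illustration of the new protocol models, the sketch below (not part of the diff) shows that all four documented input shapes validate against the same input field and that encoding_format defaults to 'float'. It assumes lmdeploy with this change installed.

# Sketch: exercising PoolingRequest with each accepted input shape.
from lmdeploy.serve.openai.protocol import PoolingRequest

for payload in (
    {'input': 'a single prompt'},             # str
    {'input': ['prompt one', 'prompt two']},   # List[str]
    {'input': [1, 2, 3]},                      # List[int]
    {'input': [[1, 2, 3], [4, 5]]},            # List[List[int]]
):
    req = PoolingRequest(**payload)
    print(type(req.input).__name__, req.encoding_format)  # encoding_format -> 'float'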
18 changes: 17 additions & 1 deletion lmdeploy/serve/utils.py
@@ -30,7 +30,7 @@ def get_reward_score(self, input_ids: List) -> List[float]:
        """
        supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']
        if self.arch not in supported_reward_models:
            raise ValueError(f'{self.arch} is not in reward mode list: {supported_reward_models}')
            raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}')
        assert isinstance(input_ids, List)
        assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
        # Make input_ids a list of token_id list
@@ -40,6 +40,22 @@ def get_reward_score(self, input_ids: List) -> List[float]:
        scores = [x[-1].cpu().item() for x in logits]
        return scores

    async def _async_get_reward_score(self, input_ids: List) -> List[float]:
        """Async version of get_reward_score."""
        supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']
        if self.arch not in supported_reward_models:
            raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}')
        assert isinstance(input_ids, List)
        assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
        # Make input_ids a list of token_id list
        input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids

        logits = await self._async_get_logits(input_ids=input_ids)

        logits = [x.squeeze() for x in logits]
        scores = [x[-1].cpu().item() for x in logits]
        return scores

    async def _async_get_logits(self,
                                input_ids,
                                steps: List[int] = None,
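Finally, a sketch of driving the new async helper directly, the same way the /pooling handler does. The checkpoint name is a placeholder and must resolve to one of the supported reward architectures (InternLM2ForRewardModel or Qwen2ForRewardModel); _async_get_reward_score is an internal method, so treat this as illustration only.

# Sketch: calling _async_get_reward_score through an lmdeploy pipeline.
import asyncio

from lmdeploy import pipeline

pipe = pipeline('some-org/some-reward-model')  # placeholder checkpoint path


async def main():
    # A flat token-id list is wrapped into a single-element batch internally,
    # so both ids and [ids] are accepted; scores come back one float per sequence.
    ids = pipe.tokenizer.encode('Is the sky blue? Yes, it is.', add_special_tokens=False)
    scores = await pipe._async_get_reward_score([ids])
    print(scores)


asyncio.run(main())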