Support math-shepherd-mistral-7b-prm model
Signed-off-by: Went-Liang <wenteng_liang@163.com>
Went-Liang committed Oct 25, 2024
1 parent 9645b9f commit d3f0ead
Showing 2 changed files with 28 additions and 0 deletions.
19 changes: 19 additions & 0 deletions vllm/model_executor/layers/pooler.py
@@ -1,3 +1,4 @@
+import os
 from enum import IntEnum
 
 import torch
@@ -13,6 +14,7 @@ class PoolingType(IntEnum):
     LAST = 0
     ALL = 1
     CLS = 2
+    STEP = 3
 
 
 class Pooler(nn.Module):
@@ -33,6 +35,9 @@ def __init__(self, pooling_type: PoolingType, normalize: bool):
 
         self.pooling_type = pooling_type
         self.normalize = normalize
+        returned_token_ids = os.environ.get('RETURNED_TOKEN_IDS', '648,387')
+        self.returned_token_ids = list(map(int, returned_token_ids.split(",")))
+        self.step_tag_id = int(os.environ.get('STEP_TOKEN_ID', -1))
 
     def forward(
         self,
@@ -58,6 +63,20 @@ def forward(
             for prompt_len in prompt_lens:
                 pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
+        elif self.pooling_type == PoolingType.STEP:
+            logits = hidden_states[:, self.returned_token_ids].softmax(dim=-1)
+            offset = 0
+            pooled_data = []
+            for prompt_len, seq_data_i in zip(
+                    prompt_lens, pooling_metadata.seq_data.values()):
+                if self.step_tag_id == -1:
+                    pooled_data.append(logits[offset:offset + prompt_len])
+                else:
+                    step_idxs = torch.tensor(
+                        seq_data_i.prompt_token_ids) == self.step_tag_id
+                    pooled_data.append(logits[offset:offset +
+                                              prompt_len][step_idxs])
+                offset += prompt_len
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
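
Taken together, the pooler.py change adds a STEP pooling mode aimed at process reward models such as math-shepherd-mistral-7b-prm: the incoming hidden_states are treated as per-token vocabulary logits, only the columns listed in RETURNED_TOKEN_IDS are kept and softmaxed, and, when STEP_TOKEN_ID is set, only the rows whose prompt token equals the step tag are returned (one score pair per reasoning step). A minimal standalone sketch of that selection with toy tensors; the token ids and shapes below are made up for illustration, not taken from the real tokenizer:

import torch

# Stand-in for the lm_head logits that LlamaForCausalLM.pooler() now feeds
# into Pooler.forward(): one row of vocabulary logits per prompt token.
vocab_size, seq_len = 1000, 6
torch.manual_seed(0)
hidden_states = torch.randn(seq_len, vocab_size)

returned_token_ids = [648, 387]   # patch default RETURNED_TOKEN_IDS='648,387'
step_tag_id = 999                 # made-up STEP_TOKEN_ID for this toy example
prompt_token_ids = torch.tensor([5, 999, 17, 999, 29, 999])

# Softmax over just the two label tokens, then keep the positions whose
# prompt token is the step tag.
logits = hidden_states[:, returned_token_ids].softmax(dim=-1)
step_idxs = prompt_token_ids == step_tag_id
print(logits[step_idxs].shape)    # torch.Size([3, 2]): one row per step tag

With the default STEP_TOKEN_ID of -1 the mask is skipped and the label probabilities for every prompt position are returned instead.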
9 changes: 9 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -540,6 +540,7 @@ def __init__(
             self.lm_head = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+        self._pooler = Pooler(pooling_type=PoolingType.STEP, normalize=False)
 
     def forward(
         self,
@@ -562,6 +563,14 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        logits = self.compute_logits(hidden_states, None)
+        return self._pooler(logits, pooling_metadata)
+
     def sample(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(logits, sampling_metadata)
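
Because the new Pooler reads its configuration from environment variables at construction time, RETURNED_TOKEN_IDS and STEP_TOKEN_ID have to be exported before the engine instantiates the model. A hypothetical end-to-end sketch; the checkpoint name, the step-tag id 12902, and the use of LLM.encode() as the pooling entrypoint are assumptions about the surrounding setup, and any extra engine flags needed to route a causal-LM checkpoint through the pooling path are outside this diff:

import os

# Assumed setup: these must be set before vLLM builds LlamaForCausalLM,
# since Pooler.__init__ reads them once at construction time.
os.environ["RETURNED_TOKEN_IDS"] = "648,387"   # '+' / '-' label token ids (patch default)
os.environ["STEP_TOKEN_ID"] = "12902"          # assumed id of the step-tag token

from vllm import LLM

llm = LLM(model="peiyi9979/math-shepherd-mistral-7b-prm")  # assumed checkpoint name
outputs = llm.encode(["Step 1: 1 + 1 = 2 ки Step 2: 2 + 2 = 4 ки"])
# Each output carries the pooled per-step '+' / '-' scores; the exact output
# field layout follows vLLM's embedding/pooling plumbing, not this patch.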
