Skip to content

Commit d34be24

Browse files
[Model] Support InternLM2 Reward models (#11571)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent b5cbe8e commit d34be24

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
450450
- Example HF Models
451451
- :ref:`LoRA <lora-adapter>`
452452
- :ref:`PP <distributed-serving>`
453+
* - :code:`InternLM2ForRewardModel`
454+
- InternLM2-based
455+
- :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
456+
- ✅︎
457+
- ✅︎
453458
* - :code:`LlamaForCausalLM`
454459
- Llama-based
455460
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class _HfExamplesInfo:
140140
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
141141
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
142142
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
143+
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
144+
trust_remote_code=True),
143145
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
144146
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
145147
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),

vllm/model_executor/models/internlm2.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
QKVParallelLinear,
1919
RowParallelLinear)
2020
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21+
from vllm.model_executor.layers.pooler import Pooler, PoolingType
2122
from vllm.model_executor.layers.quantization import QuantizationConfig
2223
from vllm.model_executor.layers.rotary_embedding import get_rope
2324
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
2425
from vllm.model_executor.layers.vocab_parallel_embedding import (
2526
ParallelLMHead, VocabParallelEmbedding)
2627
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28+
from vllm.model_executor.pooling_metadata import PoolingMetadata
2729
from vllm.model_executor.sampling_metadata import SamplingMetadata
28-
from vllm.sequence import IntermediateTensors
30+
from vllm.sequence import IntermediateTensors, PoolerOutput
2931

3032
from .interfaces import SupportsLoRA, SupportsPP
3133
from .utils import (is_pp_missing_parameter,
@@ -433,3 +435,59 @@ def load_weights(self, weights: Iterable[Tuple[str,
433435
weight_loader(param, loaded_weight)
434436
loaded_params.add(name)
435437
return loaded_params
438+
439+
440+
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 backbone repurposed as a scalar reward model.

    Builds the standard causal-LM stack, strips the generation-only
    components (LM head, logits processor, sampler), and attaches a
    ``v_head`` linear layer that maps each hidden state to a single
    reward logit.  Per-token logits are pooled with ``PoolingType.ALL``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: Type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # The reward model never generates tokens, so the LM-specific
        # modules inherited from the causal-LM parent are dropped.
        del self.output
        del self.logits_processor
        del self.sampler

        hf_config = vllm_config.model_config.hf_config
        # Projects each hidden state (hidden_size,) down to one scalar
        # reward logit.  Input is already gathered, hence
        # input_is_parallel=False.
        self.v_head = RowParallelLinear(
            hf_config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        # ALL-pooling keeps every token's logit; no normalization or
        # softmax is applied to the raw reward scores.
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and project hidden states to reward logits."""
        backbone_states = self.model(input_ids, positions, kv_caches,
                                     attn_metadata, intermediate_tensors,
                                     inputs_embeds)
        reward_logits, _ = self.v_head(backbone_states)
        return reward_logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool per-token reward logits into the final output."""
        return self._pooler(hidden_states, pooling_metadata)

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
114114
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
115115
"GritLM": ("gritlm", "GritLM"),
116+
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
116117
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
117118
"LlamaModel": ("llama", "LlamaForCausalLM"),
118119
**{

0 commit comments

Comments
 (0)