|
18 | 18 | QKVParallelLinear,
|
19 | 19 | RowParallelLinear)
|
20 | 20 | from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 21 | +from vllm.model_executor.layers.pooler import Pooler, PoolingType |
21 | 22 | from vllm.model_executor.layers.quantization import QuantizationConfig
|
22 | 23 | from vllm.model_executor.layers.rotary_embedding import get_rope
|
23 | 24 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
24 | 25 | from vllm.model_executor.layers.vocab_parallel_embedding import (
|
25 | 26 | ParallelLMHead, VocabParallelEmbedding)
|
26 | 27 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 28 | +from vllm.model_executor.pooling_metadata import PoolingMetadata |
27 | 29 | from vllm.model_executor.sampling_metadata import SamplingMetadata
|
28 |
| -from vllm.sequence import IntermediateTensors |
| 30 | +from vllm.sequence import IntermediateTensors, PoolerOutput |
29 | 31 |
|
30 | 32 | from .interfaces import SupportsLoRA, SupportsPP
|
31 | 33 | from .utils import (is_pp_missing_parameter,
|
@@ -433,3 +435,59 @@ def load_weights(self, weights: Iterable[Tuple[str,
|
433 | 435 | weight_loader(param, loaded_weight)
|
434 | 436 | loaded_params.add(name)
|
435 | 437 | return loaded_params
|
| 438 | + |
| 439 | + |
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 variant that emits a scalar reward per token.

    Reuses the causal-LM backbone but replaces the vocabulary head with a
    single-output value head (``v_head``) and exposes a pooler so the engine
    can run this model in pooling (reward) mode instead of generation.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: Type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # Generation-only machinery from the causal-LM parent is unused here;
        # drop it so weight loading / introspection don't see stale modules.
        delattr(self, "output")
        delattr(self, "logits_processor")
        delattr(self, "sampler")

        hf_config = vllm_config.model_config.hf_config
        # Scalar value head: hidden_size -> 1, no bias. Input is replicated
        # (not sharded) across tensor-parallel ranks.
        self.v_head = RowParallelLinear(
            hf_config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        # Pool every token's score; raw values, no normalization or softmax.
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and project hidden states to per-token scores."""
        backbone_out = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)
        scores, _ = self.v_head(backbone_out)
        return scores

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Delegate pooling of the per-token scores to the configured pooler."""
        return self._pooler(hidden_states, pooling_metadata)
0 commit comments