diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3455a4ccf282f..6507e63a05319 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,3 +1,4 @@ +import os from enum import IntEnum import torch @@ -13,6 +14,7 @@ class PoolingType(IntEnum): LAST = 0 ALL = 1 CLS = 2 + STEP = 3 class Pooler(nn.Module): @@ -33,6 +35,9 @@ def __init__(self, pooling_type: PoolingType, normalize: bool): self.pooling_type = pooling_type self.normalize = normalize + returned_token_ids = os.environ.get('RETURNED_TOKEN_IDS', '648,387') + self.returned_token_ids = list(map(int, returned_token_ids.split(","))) + self.step_tag_id = int(os.environ.get('STEP_TOKEN_ID', -1)) def forward( self, @@ -58,6 +63,20 @@ def forward( for prompt_len in prompt_lens: pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len + elif self.pooling_type == PoolingType.STEP: + logits = hidden_states[:, self.returned_token_ids].softmax(dim=-1) + offset = 0 + pooled_data = [] + for prompt_len, seq_data_i in zip( + prompt_lens, pooling_metadata.seq_data.values()): + if self.step_tag_id == -1: + pooled_data.append(logits[offset:offset + prompt_len]) + else: + step_idxs = torch.tensor( + seq_data_i.prompt_token_ids) == self.step_tag_id + pooled_data.append(logits[offset:offset + + prompt_len][step_idxs]) + offset += prompt_len else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c346e3e808e3f..33418db16c087 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -540,6 +540,7 @@ def __init__( self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self._pooler = Pooler(pooling_type=PoolingType.STEP, normalize=False) def forward( self, @@ -562,6 +563,14 @@ def compute_logits( sampling_metadata) return logits + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + logits = self.compute_logits(hidden_states, None) + return self._pooler(logits, pooling_metadata) + def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata)