|
| 1 | +# coding=utf-8 |
| 2 | +# Adapted from |
| 3 | +# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py |
| 4 | +# Copyright 2024 The Qwen team. |
| 5 | +# Copyright 2023 The vLLM team. |
| 6 | +"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" |
| 7 | +from typing import Iterable, List, Optional, Tuple |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch import nn |
| 11 | +from transformers import Qwen2Config |
| 12 | + |
| 13 | +from vllm.attention import AttentionMetadata |
| 14 | +from vllm.config import CacheConfig, LoRAConfig |
| 15 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 16 | + RowParallelLinear) |
| 17 | +from vllm.model_executor.layers.pooler import Pooler, PoolingType |
| 18 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 19 | + QuantizationConfig) |
| 20 | +from vllm.model_executor.model_loader.weight_utils import ( |
| 21 | + default_weight_loader, maybe_remap_kv_scale_name) |
| 22 | +from vllm.model_executor.models.qwen2 import Qwen2Model |
| 23 | +from vllm.model_executor.pooling_metadata import PoolingMetadata |
| 24 | +from vllm.sequence import IntermediateTensors, PoolerOutput |
| 25 | + |
| 26 | +from .utils import is_pp_missing_parameter |
| 27 | + |
| 28 | + |
| 29 | +class ReLU(nn.Module): |
| 30 | + |
| 31 | + def __init__(self): |
| 32 | + super().__init__() |
| 33 | + self.activation = nn.ReLU() |
| 34 | + |
| 35 | + def forward(self, input): |
| 36 | + input, _ = input |
| 37 | + return self.activation(input) |
| 38 | + |
| 39 | + |
| 40 | +class Qwen2ForRewardModel(nn.Module): |
| 41 | + packed_modules_mapping = { |
| 42 | + "qkv_proj": [ |
| 43 | + "q_proj", |
| 44 | + "k_proj", |
| 45 | + "v_proj", |
| 46 | + ], |
| 47 | + "gate_up_proj": [ |
| 48 | + "gate_proj", |
| 49 | + "up_proj", |
| 50 | + ], |
| 51 | + } |
| 52 | + |
| 53 | + # LoRA specific attributes |
| 54 | + supported_lora_modules = [ |
| 55 | + "qkv_proj", |
| 56 | + "o_proj", |
| 57 | + "gate_up_proj", |
| 58 | + "down_proj", |
| 59 | + ] |
| 60 | + embedding_modules = {} |
| 61 | + embedding_padding_modules = [] |
| 62 | + |
| 63 | + def __init__( |
| 64 | + self, |
| 65 | + config: Qwen2Config, |
| 66 | + cache_config: Optional[CacheConfig] = None, |
| 67 | + quant_config: Optional[QuantizationConfig] = None, |
| 68 | + lora_config: Optional[LoRAConfig] = None, |
| 69 | + ) -> None: |
| 70 | + # TODO (@robertgshaw2): see if this can be moved out |
| 71 | + if (cache_config.sliding_window is not None |
| 72 | + and hasattr(config, "max_window_layers")): |
| 73 | + raise ValueError("Sliding window for some but all layers is not " |
| 74 | + "supported. This model uses sliding window " |
| 75 | + "but `max_window_layers` = %s is less than " |
| 76 | + "`num_hidden_layers` = %s. Please open an issue " |
| 77 | + "to discuss this feature." % ( |
| 78 | + config.max_window_layers, |
| 79 | + config.num_hidden_layers, |
| 80 | + )) |
| 81 | + |
| 82 | + super().__init__() |
| 83 | + |
| 84 | + self.config = config |
| 85 | + self.lora_config = lora_config |
| 86 | + |
| 87 | + self.quant_config = quant_config |
| 88 | + self.model = Qwen2Model(config, cache_config, quant_config) |
| 89 | + |
| 90 | + self.score = nn.Sequential( |
| 91 | + ColumnParallelLinear(config.hidden_size, |
| 92 | + config.hidden_size, |
| 93 | + quant_config=quant_config), |
| 94 | + ReLU(), |
| 95 | + RowParallelLinear(config.hidden_size, 1, |
| 96 | + quant_config=quant_config), |
| 97 | + ) |
| 98 | + self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) |
| 99 | + |
| 100 | + def forward( |
| 101 | + self, |
| 102 | + input_ids: torch.Tensor, |
| 103 | + positions: torch.Tensor, |
| 104 | + kv_caches: List[torch.Tensor], |
| 105 | + attn_metadata: AttentionMetadata, |
| 106 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 107 | + ) -> torch.Tensor: |
| 108 | + hidden_states = self.model(input_ids, positions, kv_caches, |
| 109 | + attn_metadata, intermediate_tensors) |
| 110 | + logits, _ = self.score(hidden_states) |
| 111 | + return logits |
| 112 | + |
| 113 | + def pooler( |
| 114 | + self, |
| 115 | + hidden_states: torch.Tensor, |
| 116 | + pooling_metadata: PoolingMetadata, |
| 117 | + ) -> Optional[PoolerOutput]: |
| 118 | + return self._pooler(hidden_states, pooling_metadata) |
| 119 | + |
| 120 | + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
| 121 | + stacked_params_mapping = [ |
| 122 | + # (param_name, shard_name, shard_id) |
| 123 | + ("qkv_proj", "q_proj", "q"), |
| 124 | + ("qkv_proj", "k_proj", "k"), |
| 125 | + ("qkv_proj", "v_proj", "v"), |
| 126 | + ("gate_up_proj", "gate_proj", 0), |
| 127 | + ("gate_up_proj", "up_proj", 1), |
| 128 | + ] |
| 129 | + params_dict = dict(self.named_parameters(remove_duplicate=False)) |
| 130 | + for name, loaded_weight in weights: |
| 131 | + # Skip loading lm_head for embedding model |
| 132 | + if name == "lm_head.weight": |
| 133 | + continue |
| 134 | + if "rotary_emb.inv_freq" in name: |
| 135 | + continue |
| 136 | + for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| 137 | + if weight_name not in name: |
| 138 | + continue |
| 139 | + name = name.replace(weight_name, param_name) |
| 140 | + # Skip loading extra bias for GPTQ models. |
| 141 | + if name.endswith(".bias") and name not in params_dict: |
| 142 | + continue |
| 143 | + if is_pp_missing_parameter(name, self): |
| 144 | + continue |
| 145 | + param = params_dict[name] |
| 146 | + weight_loader = param.weight_loader |
| 147 | + weight_loader(param, loaded_weight, shard_id) |
| 148 | + break |
| 149 | + else: |
| 150 | + # Skip loading extra bias for GPTQ models. |
| 151 | + if name.endswith(".bias") and name not in params_dict: |
| 152 | + continue |
| 153 | + # Remapping the name of FP8 kv-scale. |
| 154 | + name = maybe_remap_kv_scale_name(name, params_dict) |
| 155 | + if name is None: |
| 156 | + continue |
| 157 | + if is_pp_missing_parameter(name, self): |
| 158 | + continue |
| 159 | + param = params_dict[name] |
| 160 | + weight_loader = getattr(param, "weight_loader", |
| 161 | + default_weight_loader) |
| 162 | + weight_loader(param, loaded_weight) |
0 commit comments