
Commit bc2ef1f

[Model] Support Qwen2.5-Math-RM-72B (#8896)
1 parent 2e7fe7e commit bc2ef1f

File tree

  vllm/model_executor/layers/pooler.py
  vllm/model_executor/models/__init__.py
  vllm/model_executor/models/qwen2_rm.py

3 files changed: +170 -0 lines changed

3 files changed

+170
-0
lines changed

vllm/model_executor/layers/pooler.py (+7)

@@ -11,6 +11,7 @@
 class PoolingType(IntEnum):
     """Enumeration for different types of pooling methods."""
     LAST = 0
+    ALL = 1


 class Pooler(nn.Module):
@@ -43,6 +44,12 @@ def forward(
         if self.pooling_type == PoolingType.LAST:
             last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_flat_indices]
+        elif self.pooling_type == PoolingType.ALL:
+            offset = 0
+            pooled_data = []
+            for prompt_len in prompt_lens:
+                pooled_data.append(hidden_states[offset:offset + prompt_len])
+                offset += prompt_len
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")

vllm/model_executor/models/__init__.py (+1)

@@ -74,6 +74,7 @@

 _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
+    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
 }

 _MULTIMODAL_MODELS = {
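
Registering "Qwen2ForRewardModel" in _EMBEDDING_MODELS tells vLLM to load checkpoints with that architecture name through the embedding/pooling path rather than the generation path. A rough usage sketch follows, assuming vLLM's LLM.encode() entry point and hardware that can hold the 72B checkpoint; the tensor_parallel_size value and the exact output layout are illustrative assumptions, not taken from this commit.

from vllm import LLM

# Hypothetical invocation: the registry entry above routes this architecture
# to Qwen2ForRewardModel, so encode() returns pooled per-token reward data.
llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", tensor_parallel_size=4)

prompt = "Question: What is 2 + 2?\nAnswer: 4"
outputs = llm.encode([prompt])
# With PoolingType.ALL the pooled output contains one entry per prompt token;
# inspect the returned object to see the exact field layout in your version.
print(outputs[0])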
vllm/model_executor/models/qwen2_rm.py (new file, +162)
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .utils import is_pp_missing_parameter


class ReLU(nn.Module):
    """ReLU wrapper usable inside nn.Sequential with vLLM parallel linear
    layers, which return an (output, bias) tuple instead of a bare tensor."""

    def __init__(self):
        super().__init__()
        self.activation = nn.ReLU()

    def forward(self, input):
        # Keep only the output tensor; discard the bias returned by the
        # preceding parallel linear layer.
        input, _ = input
        return self.activation(input)


class Qwen2ForRewardModel(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # TODO (@robertgshaw2): see if this can be moved out
        if (cache_config.sliding_window is not None
                and hasattr(config, "max_window_layers")):
            raise ValueError("Sliding window for some but not all layers is "
                             "not supported. This model uses sliding window "
                             "but `max_window_layers` = %s is less than "
                             "`num_hidden_layers` = %s. Please open an issue "
                             "to discuss this feature." % (
                                 config.max_window_layers,
                                 config.num_hidden_layers,
                             ))

        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = Qwen2Model(config, cache_config, quant_config)

        # Reward head: maps each token's hidden state to a single logit.
        self.score = nn.Sequential(
            ColumnParallelLinear(config.hidden_size,
                                 config.hidden_size,
                                 quant_config=quant_config),
            ReLU(),
            RowParallelLinear(config.hidden_size, 1,
                              quant_config=quant_config),
        )
        self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        logits, _ = self.score(hidden_states)
        return logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # Skip loading lm_head for embedding model
            if name == "lm_head.weight":
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked-parameter mapping matched
                # (the inner loop did not break).
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
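
The score head replaces the usual lm_head: a two-layer MLP that maps each token's hidden state to a single reward logit, with the linear layers sharded across tensor-parallel ranks. Ignoring parallelism and quantization, a plain-PyTorch equivalent of forward()'s scoring step would look roughly like the sketch below; the hidden size and token count are made up for illustration.

import torch
from torch import nn

hidden_size = 8  # illustrative; the real model uses config.hidden_size
score = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),  # stand-in for ColumnParallelLinear
    nn.ReLU(),
    nn.Linear(hidden_size, 1),            # stand-in for RowParallelLinear
)

# One hidden state per token in the flattened batch (e.g. 5 tokens).
hidden_states = torch.randn(5, hidden_size)
logits = score(hidden_states)
print(logits.shape)  # torch.Size([5, 1]): one reward logit per token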
