44# https://github.com/THUDM/CogAgent
55"""Inference-only CogAgent model compatible with THUDM weights."""
66from argparse import Namespace
7- from typing import (Iterable , List , Mapping , Optional , Sequence , Set , Tuple ,
8- TypedDict , Union )
7+ from typing import (Iterable , List , Mapping , Optional , Set , Tuple , TypedDict ,
8+ Union )
99
1010import torch
1111from torch import nn
1919from vllm .attention import Attention , AttentionMetadata
2020from vllm .config import CacheConfig , VllmConfig
2121from vllm .distributed import get_pp_group , get_tensor_model_parallel_world_size
22- from vllm .logger import init_logger
2322from vllm .model_executor .layers .activation import SiluAndMul
2423from vllm .model_executor .layers .layernorm import RMSNorm
2524from vllm .model_executor .layers .linear import (MergedColumnParallelLinear ,
3736from vllm .model_executor .sampling_metadata import SamplingMetadata
3837from vllm .multimodal import MULTIMODAL_REGISTRY
3938from vllm .multimodal .inputs import MultiModalKwargs , NestedTensors
40- from vllm .multimodal .parse import ImageSize , MultiModalDataItems
39+ from vllm .multimodal .parse import MultiModalDataItems
4140from vllm .multimodal .processing import (BaseMultiModalProcessor ,
4241 BaseProcessingInfo , BatchFeature ,
43- BoundPromptReplacement ,
4442 MultiModalFieldConfig ,
45- PlaceholderFeaturesInfo ,
4643 PromptReplacement )
4744from vllm .multimodal .profiling import BaseDummyInputsBuilder , ProcessorInputs
4845from vllm .sequence import IntermediateTensors
5350 make_empty_intermediate_tensors_factory , make_layers ,
5451 maybe_prefix , merge_multimodal_embeddings )
5552
56- logger = init_logger (__name__ )
57-
58- IMAGE_TOKEN_ID = 151329
59-
60-
61- def build_normalization_transform (image_size : int ) -> transforms .Compose :
62- """
63- Build a normalization transform which can be applied to one or
64- more input images from which we want to extract visual features.
65-
66- Args:
67- image_size: size of the image to be processed for visual embeddings.
68-
69- Returns:
70- Callable transform for normalizing and resizing one RGB image.
71- """
72-
73- return transforms .Compose ([
74- transforms .Resize (
75- (image_size , image_size ),
76- interpolation = InterpolationMode .BICUBIC ,
77- ),
78- transforms .ToTensor (),
79- transforms .Normalize (
80- (0.48145466 , 0.4578275 , 0.40821073 ),
81- (0.26862954 , 0.26130258 , 0.27577711 ),
82- ),
83- ])
84-
85-
86- def calculate_image_placeholder (vision_config ):
87- return (vision_config ["image_size" ] // vision_config ["patch_size" ] // 2 )** 2
88-
8953
9054class GLMImagePixelInputs (TypedDict ):
9155 pixel_values : torch .Tensor
@@ -109,9 +73,20 @@ def __init__(
10973 self .config = config
11074 self .tokenizer = tokenizer
11175
112- if hasattr (self .config , "vision_config" ):
113- self .image_transform = build_normalization_transform (
114- config .vision_config ["image_size" ])
76+ if vision_config := getattr (config , "vision_config" , None ):
77+ image_size = vision_config ["image_size" ]
78+
79+ self .image_transform = transforms .Compose ([
80+ transforms .Resize (
81+ (image_size , image_size ),
82+ interpolation = InterpolationMode .BICUBIC ,
83+ ),
84+ transforms .ToTensor (),
85+ transforms .Normalize (
86+ mean = (0.48145466 , 0.4578275 , 0.40821073 ),
87+ std = (0.26862954 , 0.26130258 , 0.27577711 ),
88+ ),
89+ ])
11590 else :
11691 self .image_transform = None
11792
@@ -150,9 +125,19 @@ def __call__(
150125
151126class GLM4VProcessingInfo (BaseProcessingInfo ):
152127
153- def __init__ (self , ctx ):
154- super ().__init__ (ctx )
155- self ._pre_calculate ()
128+ def get_tokenizer (self ):
129+ tokenizer = self .ctx .tokenizer
130+ assert isinstance (tokenizer , PreTrainedTokenizer )
131+ return tokenizer
132+
133+ def get_hf_config (self ):
134+ return self .ctx .get_hf_config (ChatGLMConfig )
135+
136+ def get_hf_processor (self ) -> GLM4VProcessor :
137+ return GLM4VProcessor (
138+ self .get_hf_config (),
139+ self .get_tokenizer (),
140+ )
156141
157142 def get_supported_mm_limits (self ) -> Mapping [str , Optional [int ]]:
158143 return {"image" : 1 }
@@ -162,27 +147,21 @@ def get_mm_max_tokens_per_item(
162147 seq_len : int ,
163148 mm_counts : Mapping [str , int ],
164149 ) -> Mapping [str , int ]:
150+ return {"image" : self .get_num_image_feature_tokens ()}
165151
166- return {"image" : self .image_token_num + 2 }
167-
168- def _pre_calculate (self ):
152+ def get_num_image_tokens (self ) -> int :
169153 hf_config = self .get_hf_config ()
170- vision_config = hf_config .vision_config
171- self .image_token_num = calculate_image_placeholder (vision_config )
172- self .image_size = vision_config ["image_size" ]
154+ if not (vision_config := getattr (hf_config , "vision_config" , None )):
155+ return 0
173156
174- def get_num_image_tokens (self ) -> int :
175- return self .image_token_num + 2
157+ image_size = vision_config ["image_size" ]
158+ patch_size = vision_config ["patch_size" ]
159+ grid_length = image_size // patch_size // 2
160+ return grid_length * grid_length
176161
177- def get_image_size (self ) -> ImageSize :
178-
179- return ImageSize (height = self .image_size , width = self .image_size )
180-
181- def get_hf_processor (self ) -> GLM4VProcessor :
182- return GLM4VProcessor (
183- self .get_hf_config (),
184- self .get_tokenizer (),
185- )
162+ def get_num_image_feature_tokens (self ) -> int :
163+ # EVA2CLIPModel has embeddings for boi and eoi tokens as well
164+ return self .get_num_image_tokens () + 2
186165
187166
188167class GLM4VDummyInputsBuilder (BaseDummyInputsBuilder [GLM4VProcessingInfo ]):
@@ -192,18 +171,24 @@ def get_dummy_processor_inputs(
192171 seq_len : int ,
193172 mm_counts : Mapping [str , int ],
194173 ) -> ProcessorInputs :
174+ hf_config = self .info .get_hf_config ()
175+ if not (vision_config := getattr (hf_config , "vision_config" , None )):
176+ return ProcessorInputs (prompt_text = "" , mm_data = {})
177+
178+ target_width = target_height = vision_config ["image_size" ]
195179 num_images = mm_counts .get ("image" , 0 )
196- target_width , target_height = self .info .get_image_size ()
197180
198181 mm_data = {
199182 "image" :
200183 self ._get_dummy_images (width = target_width ,
201184 height = target_height ,
202185 num_images = num_images )
203186 }
204- text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
187+
188+ base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
189+
205190 return ProcessorInputs (
206- prompt_text = text ,
191+ prompt_text = base_text * num_images ,
207192 mm_data = mm_data ,
208193 )
209194
@@ -223,47 +208,28 @@ def _get_prompt_replacements(
223208 hf_processor_mm_kwargs : Mapping [str , object ],
224209 out_mm_kwargs : MultiModalKwargs ,
225210 ) -> list [PromptReplacement ]:
211+ hf_config = self .info .get_hf_config ()
212+ if not hasattr (hf_config , "vision_config" ):
213+ return []
214+
215+ boi_token_id = hf_config .boi_token_id
216+ image_token_id = hf_config .pad_token_id
217+ eoi_token_id = hf_config .eoi_token_id
226218
227219 def get_replacement (item_idx : int ):
228- image_tokens = self .info .image_token_num
229- return [IMAGE_TOKEN_ID ] * image_tokens
220+ num_image_tokens = self .info .get_num_image_tokens ()
221+ image_tokens = [image_token_id ] * num_image_tokens
222+
223+ return [boi_token_id ] + image_tokens + [eoi_token_id ]
230224
231225 return [
232226 PromptReplacement (
233227 modality = "image" ,
234- target = [IMAGE_TOKEN_ID ],
228+ target = [boi_token_id , image_token_id , eoi_token_id ],
235229 replacement = get_replacement ,
236230 ),
237231 ]
238232
239- def _apply_prompt_replacements (
240- self ,
241- token_ids : list [int ],
242- mm_prompt_repls : Mapping [str , Sequence [BoundPromptReplacement ]],
243- mm_item_counts : Mapping [str , int ],
244- ) -> tuple [list [int ], str , Mapping [str , list [PlaceholderFeaturesInfo ]]]:
245- token_ids , text , placeholders = super ()._apply_prompt_replacements (
246- token_ids = token_ids ,
247- mm_prompt_repls = mm_prompt_repls ,
248- mm_item_counts = mm_item_counts ,
249- )
250- hf_config = self .info .get_hf_config ()
251- boi_token_id = hf_config .boi_token_id
252- eoi_token_id = hf_config .eoi_token_id
253- placeholders = {
254- modality : [
255- PlaceholderFeaturesInfo (
256- modality = p .modality ,
257- item_idx = p .item_idx ,
258- start_idx = p .start_idx - 1 ,
259- tokens = [boi_token_id ] + p .tokens + [eoi_token_id ],
260- ) for p in ps
261- ]
262- for modality , ps in placeholders .items ()
263- }
264-
265- return token_ids , text , placeholders
266-
267233
268234class GLMAttention (nn .Module ):
269235
@@ -618,7 +584,7 @@ def get_input_embeddings(
618584 multimodal_embeddings = multimodal_embeddings ,
619585 placeholder_token_id = [
620586 self .config .boi_token_id ,
621- IMAGE_TOKEN_ID ,
587+ self . config . pad_token_id ,
622588 self .config .eoi_token_id ,
623589 ],
624590 )
0 commit comments