@@ -5,8 +5,9 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 import re
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from functools import partial
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -122,6 +123,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     return blocks, target_width, target_height
 
 
+def calculate_num_blocks_wrapper(hf_config: Dict[str, Any],
+                                 max_dynamic_patch: Optional[int] = None):
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    min_num = hf_config.min_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(calculate_num_blocks,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   image_size=image_size,
+                   use_thumbnail=use_thumbnail)
+
+
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
 def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                        image_size: int,
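
Note: the wrapper above uses functools.partial to freeze the config-derived
arguments once, so callers only pass per-image dimensions. A minimal,
self-contained sketch of the pattern (toy function and values, not the real
calculate_num_blocks):

    from functools import partial

    def describe_grid(width: int, height: int, *, image_size: int,
                      use_thumbnail: bool) -> str:
        # Count how many image_size-sized tiles fit along each axis.
        cols, rows = width // image_size, height // image_size
        extra = " + thumbnail" if use_thumbnail else ""
        return f"{cols}x{rows} tiles of {image_size}px{extra}"

    # Bind the config-derived keywords once, as the wrapper does:
    describe = partial(describe_grid, image_size=448, use_thumbnail=True)
    print(describe(1920, 896))  # -> "4x2 tiles of 448px + thumbnail"
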
@@ -168,62 +183,85 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
     return pixel_values
 
 
-def get_internvl_num_patches(image_size: int, patch_size: int,
-                             downsample_ratio: float):
+def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
+                                  max_dynamic_patch: Optional[int] = None):
+    image_size = hf_config.vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(image_to_pixel_values,
+                   input_size=image_size,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   use_thumbnail=use_thumbnail)
+
+
+def get_internvl_num_patches(hf_config: Dict[str, Any]):
+    vision_config = hf_config.vision_config
+    downsample_ratio = hf_config.downsample_ratio
+    image_size = vision_config.image_size
+    patch_size = vision_config.patch_size
     return int(
         get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
         (downsample_ratio**2))
 
 
-def get_max_internvl_image_tokens(ctx: InputContext):
+def get_max_internvl_image_tokens(ctx: InputContext,
+                                  *,
+                                  max_dynamic_patch: Optional[int] = None):
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
 
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
     use_thumbnail = hf_config.use_thumbnail
-    max_dynamic_patch = hf_config.max_dynamic_patch
-    if use_thumbnail:
+    if use_thumbnail and max_dynamic_patch > 1:
         max_dynamic_patch += 1
-    downsample_ratio = hf_config.downsample_ratio
 
-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
+    num_patches = get_internvl_num_patches(hf_config)
     return num_patches * max_dynamic_patch
 
 
-def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
+def get_max_internvl_image_size(ctx: InputContext,
+                                *,
+                                max_dynamic_patch: Optional[int] = None):
+    hf_config = ctx.get_hf_config()
+    image_size = hf_config.vision_config.image_size
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    if use_thumbnail and max_dynamic_patch > 1:
+        max_dynamic_patch += 1
+    width = image_size * max_dynamic_patch
+    height = image_size
+    return width, height
+
+
+def input_processor_for_internvl(ctx: InputContext,
+                                 llm_inputs: LLMInputs,
+                                 *,
+                                 max_dynamic_patch: Optional[int] = None):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
-
-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
 
     image_data = multi_modal_data["image"]
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    use_thumbnail = hf_config.use_thumbnail
+    num_patches = get_internvl_num_patches(hf_config)
+    num_blocks_calculator = calculate_num_blocks_wrapper(
+        hf_config, max_dynamic_patch)
     if isinstance(image_data, Image.Image):
         width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                max_num, image_size,
-                                                use_thumbnail)
+        num_blocks, _, _ = num_blocks_calculator(width, height)
         image_feature_size = [num_blocks * num_patches]
     elif is_list_of(image_data, Image.Image):
         image_feature_size = []
         for image in image_data:
             width, height = image.size
-            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                    max_num, image_size,
-                                                    use_thumbnail)
+            num_blocks, _, _ = num_blocks_calculator(width, height)
             image_feature_size.append(num_blocks * num_patches)
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
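
Note: to see what get_max_internvl_image_tokens computes, here is a worked
example with typical InternVL2 config values (illustrative assumptions:
image_size=448, patch_size=14, downsample_ratio=0.5, max_dynamic_patch=12,
use_thumbnail=True):

    # Tokens per tile: CLIP patch count scaled by downsample_ratio squared.
    num_patches = int((448 // 14) ** 2 * 0.5 ** 2)  # 32*32 patches -> 256
    max_tiles = 12 + 1                              # thumbnail adds one tile
    print(num_patches * max_tiles)                  # 3328 max image tokens
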
@@ -253,31 +291,21 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
                      multi_modal_data=multi_modal_data)
 
 
-def input_mapper_for_internvl(ctx: InputContext, data: object):
+def input_mapper_for_internvl(ctx: InputContext,
+                              data: object,
+                              *,
+                              max_dynamic_patch: Optional[int] = None):
     hf_config = ctx.get_hf_config()
 
-    use_thumbnail = hf_config.use_thumbnail
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    image_size = hf_config.vision_config.image_size
-
+    image_pixel_values_mapper = image_to_pixel_values_wrapper(
+        hf_config, max_dynamic_patch)
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data,
-                                     image_size,
-                                     min_num,
-                                     max_num,
-                                     use_thumbnail=use_thumbnail)
+        data = image_pixel_values_mapper(data)
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
     elif is_list_of(data, Image.Image):
         # we can't stack here because the images may have different num_patches
-        data = [
-            image_to_pixel_values(img,
-                                  image_size,
-                                  min_num,
-                                  max_num,
-                                  use_thumbnail=use_thumbnail) for img in data
-        ]
+        data = [image_pixel_values_mapper(img) for img in data]
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
@@ -292,35 +320,36 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
     })
 
 
-def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
-                            mm_counts: Mapping[str, int]):
+def dummy_data_for_internvl(ctx: InputContext,
+                            seq_len: int,
+                            mm_counts: Mapping[str, int],
+                            *,
+                            max_dynamic_patch: Optional[int] = None):
     num_images = mm_counts["image"]
 
-    image_feature_size = get_max_internvl_image_tokens(ctx)
-    model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
+
+    image_feature_size = get_max_internvl_image_tokens(
+        ctx, max_dynamic_patch=max_dynamic_patch)
+    model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
         trust_remote_code=model_config.trust_remote_code)
 
     seq_data = dummy_seq_data_for_clip(
-        vision_config,
+        hf_config.vision_config,
         seq_len,
         num_images,
         image_token_id=tokenizer.encode(IMG_CONTEXT,
                                         add_special_tokens=False)[0],
         image_feature_size_override=image_feature_size,
     )
 
-    image_size = vision_config.image_size
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    max_image_width = max_num * image_size
-    max_image_height = min_num * image_size
+    max_image_width, max_image_height = get_max_internvl_image_size(
+        ctx, max_dynamic_patch=max_dynamic_patch)
 
     mm_data = dummy_image_for_clip(
-        vision_config,
+        hf_config.vision_config,
         num_images,
         image_width_override=max_image_width,
         image_height_override=max_image_height,
@@ -470,7 +499,6 @@ def _process_image_input(
         self,
         image_input: InternVLImageInputs,
     ) -> torch.Tensor:
-
         if image_input["type"] == "image_embeds":
             return image_input["data"]
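
Note: every entry point in this diff takes max_dynamic_patch as a
keyword-only argument defaulting to None, so a caller can pre-bind an
override while the unbound functions keep reading the HF config. A toy,
self-contained sketch of that override pattern (stand-in function, not
vLLM API):

    from functools import partial
    from typing import Optional

    CONFIG_DEFAULT = 12  # stands in for hf_config.max_dynamic_patch

    def process(data: str, *, max_dynamic_patch: Optional[int] = None) -> str:
        if max_dynamic_patch is None:
            max_dynamic_patch = CONFIG_DEFAULT  # fall back to the config
        return f"{data}: up to {max_dynamic_patch} tiles"

    print(process("img"))                                # img: up to 12 tiles
    print(partial(process, max_dynamic_patch=4)("img"))  # img: up to 4 tiles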