-
Notifications
You must be signed in to change notification settings - Fork 67
update divprune,mustdrop for llava-next #428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| base: | ||
| seed: &seed 42 | ||
| model: | ||
| type: Llava | ||
| path: model path | ||
| torch_dtype: auto | ||
| eval: | ||
| eval_pos: [pretrain, transformed] | ||
| type: vqa | ||
| name: [mme] | ||
| download: False | ||
| path: MME dataset path | ||
| bs: 1 | ||
| inference_per_block: False | ||
| sparse: | ||
| method: TokenReduction | ||
| special: | ||
| method: DivPrune | ||
| reduction_ratio: 0.9444 # 0.7778 0.8889 0.9444 | ||
| save: | ||
| save_trans: False | ||
| save_fake: False | ||
| save_path: /path/to/save/ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| base: | ||
| seed: &seed 42 | ||
| model: | ||
| type: Llava | ||
| path: model path | ||
| torch_dtype: auto | ||
| eval: | ||
| eval_pos: [pretrain, transformed] | ||
| type: vqa | ||
| name: [mme] | ||
| download: False | ||
| path: MME dataset path | ||
| bs: 1 | ||
| sparse: | ||
| vision: | ||
| method: TokenReduction | ||
| special: | ||
| method: MustDrop | ||
| spatial_threshold: 0.6 | ||
| window_size: [3, 3] | ||
| retained_tokens: 128 # llava_next: 128, 64, 32 llava: 192, 128, 64 | ||
| save: | ||
| save_trans: False | ||
| save_fake: False | ||
| save_path: /path/to/save/ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,3 @@ | ||
| import functools | ||
| from functools import wraps | ||
| from types import MethodType | ||
|
|
||
|
|
@@ -7,7 +6,6 @@ | |
| from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY | ||
|
|
||
| from .token_reduction_module import TokenReductionModule | ||
| from .utils import prefill_wrapper | ||
|
|
||
|
|
||
| def pairwise_cosine_similarity(matrix): | ||
|
|
@@ -22,7 +20,7 @@ def divprune( | |
| cosine_matrix=None, | ||
| threshold_ratio=0.1, | ||
| ): | ||
| threshold_terms = int(round(threshold_ratio * image_feature_length)) | ||
| threshold_terms = round(threshold_ratio * image_feature_length) | ||
| if cosine_matrix is None: | ||
| cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors)) | ||
|
|
||
|
|
@@ -53,22 +51,16 @@ def divprune( | |
| return s, cosine_matrix | ||
|
|
||
|
|
||
| def divprune_post_hook( | ||
| input_ids, | ||
| position_ids, | ||
| attention_mask, | ||
| past_key_values, | ||
| inputs_embeds, | ||
| labels, | ||
| pruning_paras=None, | ||
| ): | ||
| rate = pruning_paras['rate'] | ||
| SYS_TOKEN_LEN = pruning_paras['image_token_start_index'] | ||
| img_feature_len = pruning_paras['image_token_length'] | ||
| def divprune_post_hook(*args, pruning_paras=None): | ||
| args = list(args) | ||
| position_ids, attention_mask, inputs_embeds = args[1], args[2], args[4] | ||
| rate = pruning_paras['reduction_ratio'] | ||
| SYS_TOKEN_LEN = pruning_paras['vision_token_start_index'] | ||
| img_feature_len = pruning_paras['vision_token_length'] | ||
| device = inputs_embeds.device | ||
| visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len] | ||
| selected_visual_tokens, cosine_matrix = divprune( | ||
| visual_tokens, img_feature_len, None, threshold_ratio=rate | ||
| visual_tokens, img_feature_len, None, threshold_ratio=1 - rate | ||
| ) | ||
|
|
||
| selected_visual_tokens += SYS_TOKEN_LEN | ||
|
|
@@ -83,20 +75,13 @@ def divprune_post_hook( | |
| ) | ||
| keep_indexs = keep_indexs.sort().values | ||
|
|
||
| inputs_embeds = inputs_embeds[:, keep_indexs] | ||
| if position_ids is not None: | ||
| position_ids = position_ids[:, keep_indexs, :] | ||
| args[1] = position_ids[:, keep_indexs, :] | ||
| if attention_mask is not None: | ||
| attention_mask = attention_mask[:, keep_indexs] | ||
|
|
||
| return ( | ||
| input_ids, | ||
| position_ids, | ||
| attention_mask, | ||
| past_key_values, | ||
| inputs_embeds, | ||
| labels, | ||
| ) | ||
| args[2] = attention_mask[:, keep_indexs] | ||
| args[4] = inputs_embeds[:, keep_indexs] | ||
|
|
||
| return tuple(args) | ||
|
|
||
|
|
||
| @TOKEN_REDUCTION_REGISTRY.register('DivPrune') | ||
|
|
@@ -107,43 +92,34 @@ def __init__(self, config, model, blocks): | |
| self.register_reduction_modules() | ||
|
|
||
| def add_sparse_config(self): | ||
| self.special_config['image_token_length'] = self.model.pruning_config[ | ||
| 'image_token_length' | ||
| ] | ||
|
|
||
| self.pruning_paras = self.special_config | ||
|
|
||
| def register_reduction_modules(self): | ||
|
|
||
| def input_hook_llava(fn, pruning_paras): | ||
| def input_hook_llava(fn, pruning_paras, llava_next): | ||
| @wraps(fn) | ||
| def wrapper(self, *args, **kwargs): | ||
| if len(args) == 0: | ||
| return fn(*args, **kwargs) | ||
| input_args = args[0] | ||
| if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1: | ||
| if args[0].shape[1] == 1: | ||
| return fn(*args, **kwargs) | ||
|
|
||
| input_ids = args[0] | ||
| attention_mask = args[2] | ||
| token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX | ||
| pruning_paras['image_token_start_index'] = torch.where(token_indices)[ | ||
| 0 | ||
| ][0].item() | ||
|
|
||
| outputs = fn(*args, **kwargs) | ||
|
|
||
| return divprune_post_hook(*outputs, pruning_paras=pruning_paras) | ||
|
|
||
| outs = fn(*args, **kwargs) | ||
|
|
||
| if llava_next: | ||
| message = ( | ||
| 'To obtain the vision_token_length for LLaVA-1.6, you should append ' | ||
| '`image_features[0].shape[0]` to the return value of the function ' | ||
| '`prepare_inputs_labels_for_multimodal`, and modify the related code.' | ||
| ) | ||
| assert len(outs) == 7, message | ||
|
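There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using `assert` for this validation is fragile, since assertions are stripped when Python runs with `-O`. A more robust approach would be to handle this case gracefully, for instance, by checking the length of `outs` and raising an explicit error. |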
||
| pruning_paras['vision_token_length'] = outs[-1] | ||
| return divprune_post_hook(*outs, pruning_paras=pruning_paras) | ||
| return wrapper | ||
|
|
||
| if self.model.__class__.__name__ == 'Llava': | ||
| from llava.constants import IMAGE_TOKEN_INDEX | ||
|
|
||
| hook_fn = input_hook_llava( | ||
| self.model.vlm_model.prepare_inputs_labels_for_multimodal, | ||
| self.pruning_paras, | ||
| ) | ||
| self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( | ||
| hook_fn, self.model.vlm_model | ||
| input_hook_llava( | ||
| self.model.vlm_model.prepare_inputs_labels_for_multimodal, | ||
| self.pruning_paras, | ||
| llava_next=self.special_config['vision_token_length'] is None | ||
| ), self.model.vlm_model | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,10 +1,16 @@ | ||||||||
| import functools | ||||||||
| import math | ||||||||
| from types import MethodType | ||||||||
| from typing import Callable, Tuple | ||||||||
|
|
||||||||
| import torch | ||||||||
| import torch.nn.functional as F | ||||||||
| from einops import rearrange | ||||||||
|
|
||||||||
| from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY | ||||||||
|
|
||||||||
| from .token_reduction_module import TokenReductionModule | ||||||||
| from .utils import prepare_inputs_labels_for_multimodal_with_index_masks | ||||||||
|
|
||||||||
|
|
||||||||
| @TOKEN_REDUCTION_REGISTRY.register('MustDrop') | ||||||||
|
|
@@ -15,18 +21,11 @@ def __init__(self, config, model, blocks): | |||||||
| self.register_reduction_modules() | ||||||||
|
|
||||||||
| def add_sparse_config(self): | ||||||||
| self.pruning_loc = self.special_config['pruning_loc'] | ||||||||
| self.pruning_loc = self.model.pruning_config.get('select_layer', -1) | ||||||||
| self.pruning_paras = self.special_config | ||||||||
|
|
||||||||
| def register_reduction_modules(self): | ||||||||
|
|
||||||||
| import math | ||||||||
| from typing import Callable, Tuple | ||||||||
|
|
||||||||
| import numpy as np | ||||||||
| import torch.nn.functional as F | ||||||||
| from einops import rearrange | ||||||||
|
|
||||||||
| def conditional_pooling( | ||||||||
| feat: torch.Tensor, | ||||||||
| threshold: float, | ||||||||
|
|
@@ -170,7 +169,14 @@ def merge(x: torch.Tensor, mode='mean') -> torch.Tensor: | |||||||
| ) | ||||||||
| x = torch.cat([dst, unm], dim=1) | ||||||||
| x = torch.cat((x_cls, x), dim=1) | ||||||||
| return x | ||||||||
|
|
||||||||
| index_masks = torch.zeros((n, t1), dtype=torch.bool, device=x_feat.device) | ||||||||
| dst_flat = dst_idx.view(n, -1) | ||||||||
| unm_flat = unm_idx.view(n, -1) | ||||||||
| index_masks.scatter_(1, dst_flat, True) | ||||||||
| index_masks.scatter_(1, unm_flat, True) | ||||||||
|
|
||||||||
| return x, index_masks | ||||||||
|
|
||||||||
| return merge | ||||||||
|
|
||||||||
|
|
@@ -181,26 +187,49 @@ def merge_wavg( | |||||||
| if size is None: | ||||||||
| size = torch.ones_like(x[..., 0, None]) | ||||||||
|
|
||||||||
| x = merge(x * size, mode='sum') | ||||||||
| size = merge(size, mode='sum') | ||||||||
| x, index_masks = merge(x * size, mode='sum') | ||||||||
| size, _ = merge(size, mode='sum') | ||||||||
| x = x / size | ||||||||
|
|
||||||||
| return x, size | ||||||||
| return x, size, index_masks | ||||||||
|
|
||||||||
| def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras): | ||||||||
| def spatial_merge_hook(module, inps, outs, pruning_paras, llava_next): | ||||||||
| spatial_threshold = pruning_paras['spatial_threshold'] | ||||||||
| window_size = pruning_paras['window_size'] | ||||||||
| hidden_states = layer_outs[0] | ||||||||
| hidden_states = outs[0] | ||||||||
| vtoken_length = hidden_states.shape[1] | ||||||||
| fix_r = 0 | ||||||||
| if pruning_paras.get('retained_tokens', None) is not None: | ||||||||
| retained_tokens = pruning_paras['retained_tokens'] | ||||||||
| fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \ | ||||||||
| fix_r = (vtoken_length - retained_tokens) \ | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The calculation for
Suggested change
|
||||||||
| // (window_size[0] * window_size[1] - 1) | ||||||||
| merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r) | ||||||||
| hidden_states, size = merge_wavg(merge, hidden_states, None) | ||||||||
| return (hidden_states,) | ||||||||
| hidden_states, size, index_masks = merge_wavg(merge, hidden_states, None) | ||||||||
|
|
||||||||
| if not llava_next: | ||||||||
| return (hidden_states,) | ||||||||
|
|
||||||||
| self.blocks[self.pruning_loc - 1].register_forward_hook( | ||||||||
| functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras), | ||||||||
| with_kwargs=True, | ||||||||
| pruning_paras['index_masks'] = index_masks | ||||||||
| return outs | ||||||||
|
|
||||||||
| def update_index_masks_hook(module, inps, outs, pruning_paras): | ||||||||
| module.index_masks = pruning_paras['index_masks'] | ||||||||
|
|
||||||||
| self.blocks[self.pruning_loc].register_forward_hook( | ||||||||
| functools.partial( | ||||||||
| spatial_merge_hook, | ||||||||
| pruning_paras=self.pruning_paras, | ||||||||
| llava_next=self.special_config['vision_token_length'] is None | ||||||||
| ), | ||||||||
| ) | ||||||||
|
|
||||||||
| if self.special_config['vision_token_length'] is None: | ||||||||
|
|
||||||||
| self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( | ||||||||
| prepare_inputs_labels_for_multimodal_with_index_masks, | ||||||||
| self.model.vlm_model | ||||||||
| ) | ||||||||
|
|
||||||||
| self.model.vision_model.register_forward_hook( | ||||||||
| functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras), | ||||||||
| ) | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The threshold ratio is being subtracted from 1. Ensure this is the intended behavior, as it changes the meaning of the
`threshold_ratio` parameter. If the intention is to use the reduction ratio directly, this subtraction is unnecessary and could lead to confusion.