Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions configs/sparsification/methods/DivPrune/divprune.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
base:
seed: &seed 42
model:
type: Llava
path: model path
torch_dtype: auto
eval:
eval_pos: [pretrain, transformed]
type: vqa
name: [mme]
download: False
path: MME dataset path
bs: 1
inference_per_block: False
sparse:
method: TokenReduction
special:
method: DivPrune
reduction_ratio: 0.9444 # 0.7778 0.8889 0.9444
save:
save_trans: False
save_fake: False
save_path: /path/to/save/
25 changes: 25 additions & 0 deletions configs/sparsification/methods/MustDrop/mustdrop.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
base:
seed: &seed 42
model:
type: Llava
path: model path
torch_dtype: auto
eval:
eval_pos: [pretrain, transformed]
type: vqa
name: [mme]
download: False
path: MME dataset path
bs: 1
sparse:
vision:
method: TokenReduction
special:
method: MustDrop
spatial_threshold: 0.6
window_size: [3, 3]
retained_tokens: 128 # llava_next: 128, 64, 32; llava: 192, 128, 64
save:
save_trans: False
save_fake: False
save_path: /path/to/save/
86 changes: 31 additions & 55 deletions llmc/compression/token_reduction/divprune.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
from functools import wraps
from types import MethodType

Expand All @@ -7,7 +6,6 @@
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper


def pairwise_cosine_similarity(matrix):
Expand All @@ -22,7 +20,7 @@ def divprune(
cosine_matrix=None,
threshold_ratio=0.1,
):
threshold_terms = int(round(threshold_ratio * image_feature_length))
threshold_terms = round(threshold_ratio * image_feature_length)
if cosine_matrix is None:
cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors))

Expand Down Expand Up @@ -53,22 +51,16 @@ def divprune(
return s, cosine_matrix


def divprune_post_hook(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
pruning_paras=None,
):
rate = pruning_paras['rate']
SYS_TOKEN_LEN = pruning_paras['image_token_start_index']
img_feature_len = pruning_paras['image_token_length']
def divprune_post_hook(*args, pruning_paras=None):
args = list(args)
position_ids, attention_mask, inputs_embeds = args[1], args[2], args[4]
rate = pruning_paras['reduction_ratio']
SYS_TOKEN_LEN = pruning_paras['vision_token_start_index']
img_feature_len = pruning_paras['vision_token_length']
device = inputs_embeds.device
visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
selected_visual_tokens, cosine_matrix = divprune(
visual_tokens, img_feature_len, None, threshold_ratio=rate
visual_tokens, img_feature_len, None, threshold_ratio=1 - rate

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

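The value passed as `threshold_ratio` is now `1 - rate`, i.e. the configured `reduction_ratio` subtracted from 1. Ensure this is the intended behavior, as it changes the meaning of the value that `divprune` receives as its `threshold_ratio` parameter (`divprune` multiplies it by the image feature length to decide how many tokens to select). If the intention is to use the reduction ratio directly, this subtraction is unnecessary and could lead to confusion.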

)

selected_visual_tokens += SYS_TOKEN_LEN
Expand All @@ -83,20 +75,13 @@ def divprune_post_hook(
)
keep_indexs = keep_indexs.sort().values

inputs_embeds = inputs_embeds[:, keep_indexs]
if position_ids is not None:
position_ids = position_ids[:, keep_indexs, :]
args[1] = position_ids[:, keep_indexs, :]
if attention_mask is not None:
attention_mask = attention_mask[:, keep_indexs]

return (
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
)
args[2] = attention_mask[:, keep_indexs]
args[4] = inputs_embeds[:, keep_indexs]

return tuple(args)


@TOKEN_REDUCTION_REGISTRY.register('DivPrune')
Expand All @@ -107,43 +92,34 @@ def __init__(self, config, model, blocks):
self.register_reduction_modules()

def add_sparse_config(self):
self.special_config['image_token_length'] = self.model.pruning_config[
'image_token_length'
]

self.pruning_paras = self.special_config

def register_reduction_modules(self):

def input_hook_llava(fn, pruning_paras):
def input_hook_llava(fn, pruning_paras, llava_next):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
if args[0].shape[1] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]
token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
pruning_paras['image_token_start_index'] = torch.where(token_indices)[
0
][0].item()

outputs = fn(*args, **kwargs)

return divprune_post_hook(*outputs, pruning_paras=pruning_paras)

outs = fn(*args, **kwargs)

if llava_next:
message = (
'To obtain the vision_token_length for LLaVA-1.6, you should append '
'`image_features[0].shape[0]` to the return value of the function '
'`prepare_inputs_labels_for_multimodal`, and modify the related code.'
)
assert len(outs) == 7, message

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

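Using `assert` to enforce an API contract on an external library is brittle. If the `prepare_inputs_labels_for_multimodal` function in the llava library does not return the expected seven values (for example because it has not been patched as the message describes, or because its return signature changes), this will cause the application to crash with an `AssertionError`. In addition, when Python runs with `-O`, assertions are stripped, so the check is skipped entirely and `outs[-1]` is used unchecked. The assertion message is helpful for developers but doesn't prevent the crash.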

A more robust approach would be to handle this case gracefully, for instance, by checking the length of outs and raising a more specific, informative exception or logging a clear error message.

pruning_paras['vision_token_length'] = outs[-1]
return divprune_post_hook(*outs, pruning_paras=pruning_paras)
return wrapper

if self.model.__class__.__name__ == 'Llava':
from llava.constants import IMAGE_TOKEN_INDEX

hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras,
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras,
llava_next=self.special_config['vision_token_length'] is None
), self.model.vlm_model
)
69 changes: 49 additions & 20 deletions llmc/compression/token_reduction/mustdrop.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import functools
import math
from types import MethodType
from typing import Callable, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prepare_inputs_labels_for_multimodal_with_index_masks


@TOKEN_REDUCTION_REGISTRY.register('MustDrop')
Expand All @@ -15,18 +21,11 @@ def __init__(self, config, model, blocks):
self.register_reduction_modules()

def add_sparse_config(self):
self.pruning_loc = self.special_config['pruning_loc']
self.pruning_loc = self.model.pruning_config.get('select_layer', -1)
self.pruning_paras = self.special_config

def register_reduction_modules(self):

import math
from typing import Callable, Tuple

import numpy as np
import torch.nn.functional as F
from einops import rearrange

def conditional_pooling(
feat: torch.Tensor,
threshold: float,
Expand Down Expand Up @@ -170,7 +169,14 @@ def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
)
x = torch.cat([dst, unm], dim=1)
x = torch.cat((x_cls, x), dim=1)
return x

index_masks = torch.zeros((n, t1), dtype=torch.bool, device=x_feat.device)
dst_flat = dst_idx.view(n, -1)
unm_flat = unm_idx.view(n, -1)
index_masks.scatter_(1, dst_flat, True)
index_masks.scatter_(1, unm_flat, True)

return x, index_masks

return merge

Expand All @@ -181,26 +187,49 @@ def merge_wavg(
if size is None:
size = torch.ones_like(x[..., 0, None])

x = merge(x * size, mode='sum')
size = merge(size, mode='sum')
x, index_masks = merge(x * size, mode='sum')
size, _ = merge(size, mode='sum')
x = x / size

return x, size
return x, size, index_masks

def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras):
def spatial_merge_hook(module, inps, outs, pruning_paras, llava_next):
spatial_threshold = pruning_paras['spatial_threshold']
window_size = pruning_paras['window_size']
hidden_states = layer_outs[0]
hidden_states = outs[0]
vtoken_length = hidden_states.shape[1]
fix_r = 0
if pruning_paras.get('retained_tokens', None) is not None:
retained_tokens = pruning_paras['retained_tokens']
fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
fix_r = (vtoken_length - retained_tokens) \

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The calculation for fix_r can result in a negative value if retained_tokens is greater than vtoken_length. While this might not happen with current configurations, it could lead to unexpected behavior in the conditional_pooling function if inputs change. It would be safer to ensure fix_r is non-negative.

Suggested change
fix_r = (vtoken_length - retained_tokens) \
fix_r = max(0, (vtoken_length - retained_tokens) // \
(window_size[0] * window_size[1] - 1))

// (window_size[0] * window_size[1] - 1)
merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r)
hidden_states, size = merge_wavg(merge, hidden_states, None)
return (hidden_states,)
hidden_states, size, index_masks = merge_wavg(merge, hidden_states, None)

if not llava_next:
return (hidden_states,)

self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
with_kwargs=True,
pruning_paras['index_masks'] = index_masks
return outs

def update_index_masks_hook(module, inps, outs, pruning_paras):
module.index_masks = pruning_paras['index_masks']

self.blocks[self.pruning_loc].register_forward_hook(
functools.partial(
spatial_merge_hook,
pruning_paras=self.pruning_paras,
llava_next=self.special_config['vision_token_length'] is None
),
)

if self.special_config['vision_token_length'] is None:

self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
prepare_inputs_labels_for_multimodal_with_index_masks,
self.model.vlm_model
)

self.model.vision_model.register_forward_hook(
functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
)
Loading