4 | 4 | # SPDX-License-Identifier: BSD-3-Clause |
5 | 5 | # |
6 | 6 | # ----------------------------------------------------------------------------- |
7 | | -from typing import List, Optional, Tuple, Union |
8 | 7 |
9 | 8 | import torch |
10 | 9 | import torch.utils.checkpoint |
11 | | -from torch import nn |
12 | 10 | from transformers.models.llava.modeling_llava import ( |
13 | | - LlavaCausalLMOutputWithPast, |
14 | 11 | LlavaForConditionalGeneration, |
15 | | - logger, |
16 | 12 | ) |
17 | 13 |
18 | 14 | BS = 1 |
23 | 19 |
24 | 20 |
25 | 21 | class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration): |
26 | | - def forward( |
27 | | - self, |
28 | | - input_ids: torch.LongTensor = None, |
29 | | - pixel_values: torch.FloatTensor = None, |
30 | | - attention_mask: Optional[torch.Tensor] = None, |
31 | | - position_ids: Optional[torch.LongTensor] = None, |
32 | | - past_key_values: Optional[List[torch.FloatTensor]] = None, |
33 | | - inputs_embeds: Optional[torch.FloatTensor] = None, |
34 | | - vision_feature_layer: Optional[int] = None, |
35 | | - vision_feature_select_strategy: Optional[str] = None, |
36 | | - labels: Optional[torch.LongTensor] = None, |
37 | | - use_cache: Optional[bool] = None, |
38 | | - output_attentions: Optional[bool] = None, |
39 | | - output_hidden_states: Optional[bool] = None, |
40 | | - return_dict: Optional[bool] = None, |
41 | | - cache_position: Optional[torch.LongTensor] = None, |
42 | | - num_logits_to_keep: int = 0, |
43 | | - ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: |
44 | | - r""" |
45 | | - Args: |
46 | | - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
47 | | - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
48 | | - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
49 | | - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
50 | | - |
51 | | - num_logits_to_keep (`int`, *optional*): |
52 | | - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all |
53 | | - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that |
54 | | - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. |
55 | | - |
56 | | - |
57 | | - Returns: |
58 | | - |
59 | | - Example: |
60 | | - |
61 | | - ```python |
62 | | - >>> from PIL import Image |
63 | | - >>> import requests |
64 | | - >>> from transformers import AutoProcessor, LlavaForConditionalGeneration |
65 | | - |
66 | | - >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") |
67 | | - >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") |
68 | | - |
69 | | - >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:" |
70 | | - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" |
71 | | - >>> image = Image.open(requests.get(url, stream=True).raw) |
72 | | - |
73 | | - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") |
74 | | - |
75 | | - >>> # Generate |
76 | | - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) |
77 | | - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
78 | | - "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" |
79 | | - ```""" |
80 | | - |
81 | | - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
82 | | - output_hidden_states = ( |
83 | | - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
84 | | - ) |
85 | | - return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
86 | | - vision_feature_layer = ( |
87 | | - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer |
88 | | - ) |
89 | | - vision_feature_select_strategy = ( |
90 | | - vision_feature_select_strategy |
91 | | - if vision_feature_select_strategy is not None |
92 | | - else self.config.vision_feature_select_strategy |
93 | | - ) |
94 | | - |
95 | | - if (input_ids is None) ^ (inputs_embeds is not None): |
96 | | - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
97 | | - |
98 | | - if pixel_values is not None and inputs_embeds is not None: |
99 | | - raise ValueError( |
100 | | - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" |
101 | | - ) |
102 | | - |
103 | | - legacy_processing = False |
104 | | - if inputs_embeds is None: |
105 | | - inputs_embeds = self.get_input_embeddings()(input_ids) |
106 | | - |
107 | | - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing |
108 | | - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt |
109 | | - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True |
110 | | - legacy_processing = ( |
111 | | - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length |
112 | | - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) |
113 | | - |
114 | | - if pixel_values is not None: |
115 | | - image_features = self.get_image_features( |
116 | | - pixel_values=pixel_values, |
117 | | - vision_feature_layer=vision_feature_layer, |
118 | | - vision_feature_select_strategy=vision_feature_select_strategy, |
119 | | - ) |
120 | | - |
121 | | - if legacy_processing: |
122 | | - logger.warning_once( |
123 | | - "Expanding inputs for image tokens in LLaVa should be done in processing. " |
124 | | - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " |
125 | | - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " |
126 | | - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." |
127 | | - ) |
128 | | - # prefill stage vs decoding stage (legacy behavior copied) |
129 | | - if input_ids.shape[1] != 1: |
130 | | - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( |
131 | | - image_features, inputs_embeds, input_ids, attention_mask, labels |
132 | | - ) |
133 | | - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) |
134 | | - else: |
135 | | - # Retrieve the first layer to inspect the logits and mask out the hidden states |
136 | | - # that are set to 0 |
137 | | - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] |
138 | | - |
139 | | - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 |
140 | | - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) |
141 | | - |
142 | | - # Get the target length |
143 | | - target_length = input_ids.shape[1] |
144 | | - past_length = first_layer_past_key_value.shape[-1] |
145 | | - |
146 | | - extended_attention_mask = torch.ones( |
147 | | - (attention_mask.shape[0], past_length), |
148 | | - dtype=attention_mask.dtype, |
149 | | - device=attention_mask.device, |
150 | | - ) |
151 | | - |
152 | | - # Filter out only the tokens that can be un-attended, this can happen |
153 | | - # if one uses Llava + Fused modules where the cache on the |
154 | | - # first iteration is already big enough, or if one passes custom cache |
155 | | - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) |
156 | | - new_batch_index = batch_index[valid_indices] |
157 | | - new_non_attended_tokens = non_attended_tokens[valid_indices] |
158 | | - |
159 | | - # Zero-out the places where we don't need to attend |
160 | | - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 |
161 | | - |
162 | | - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) |
163 | | - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 |
164 | | - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ |
165 | | - -target_length: |
166 | | - ] |
167 | | - |
168 | | - # TODO: @raushan retain only the new behavior after v4.47 |
169 | | - else: |
170 | | - n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() |
171 | | - n_image_features = image_features.shape[1] |
172 | | - if n_image_tokens != n_image_features: |
173 | | - raise ValueError( |
174 | | - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
175 | | - ) |
176 | | - |
177 | | - mask = input_ids == self.config.image_token_index |
178 | | - indices1 = mask.to(torch.int64).cumsum(1) - 1 |
179 | | - indices0 = torch.arange(mask.shape[0]).view(-1, 1) |
180 | | - image_features_expanded = image_features[indices0, indices1] |
181 | | - image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) |
182 | | - # *where to skip image encoder for decode* |
183 | | - inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds) |
184 | | - |
| 22 | + def forward(self, input_ids, position_ids, pixel_values, past_key_values): |
| 23 | + inputs_embeds = self.get_input_embeddings()(input_ids) |
| 24 | + # Image features |
| 25 | + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) |
| 26 | + selected_image_feature = image_outputs.hidden_states[self.config.vision_feature_layer] |
| 27 | + vision_feature_select_strategy = self.config.vision_feature_select_strategy |
| 28 | + if vision_feature_select_strategy == "default": |
| 29 | + selected_image_feature = selected_image_feature[:, 1:] |
| 30 | + elif vision_feature_select_strategy == "full": |
| 31 | + selected_image_feature = selected_image_feature |
| 32 | + else: |
| 33 | + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") |
| 34 | + image_features = self.multi_modal_projector(selected_image_feature) |
| 35 | + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| 36 | + |
| 37 | + mask = input_ids == self.config.image_token_index |
| 38 | + indices1 = mask.to(torch.int64).cumsum(1) - 1 |
| 39 | + indices0 = torch.arange(mask.shape[0]).view(-1, 1) |
| 40 | + image_features_expanded = image_features[indices0, indices1] |
| 41 | + image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) |
| 42 | + # *where to skip image encoder for decode* |
| 43 | + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds) |
185 | 44 | outputs = self.language_model( |
186 | | - attention_mask=attention_mask, |
| 45 | + inputs_embeds=inputs_embeds, |
187 | 46 | position_ids=position_ids, |
188 | 47 | past_key_values=past_key_values, |
189 | | - inputs_embeds=inputs_embeds, |
190 | | - use_cache=use_cache, |
191 | | - output_attentions=output_attentions, |
192 | | - output_hidden_states=output_hidden_states, |
193 | | - return_dict=return_dict, |
194 | | - cache_position=cache_position, |
195 | | - num_logits_to_keep=num_logits_to_keep, |
196 | 48 | ) |
197 | | - |
198 | | - logits = outputs[0] |
199 | | - |
200 | | - loss = None |
201 | | - if labels is not None: |
202 | | - # Shift so that tokens < n predict n |
203 | | - if attention_mask is not None: |
204 | | - # we use the input attention mask to shift the logits and labels, because it is 2D. |
205 | | - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft |
206 | | - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) |
207 | | - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() |
208 | | - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() |
209 | | - else: |
210 | | - shift_logits = logits[..., :-1, :].contiguous() |
211 | | - shift_labels = labels[..., 1:].contiguous() |
212 | | - # Flatten the tokens |
213 | | - loss_fct = nn.CrossEntropyLoss() |
214 | | - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)) |
215 | | - |
216 | | - if not return_dict: |
217 | | - output = (logits,) + outputs[1:] |
218 | | - return (loss,) + output if loss is not None else output |
219 | | - |
220 | | - return logits, pixel_values, outputs.past_key_values |
| 49 | + return outputs.logits, pixel_values, outputs.past_key_values |
221 | 50 |
222 | 51 | def get_dummy_inputs(self, **kwargs): |
223 | 52 | num_layers = self.config.text_config.num_hidden_layers |
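The heart of the new `forward` above is the cumsum-gather plus `torch.where` merge of the projected image features into the text embeddings, which replaces the data-dependent legacy merge path carried by the removed code, presumably to keep shapes static and the graph export-friendly. Below is a minimal, self-contained sketch of that merge with toy tensors; the `image_token_index` value, hidden size, and inputs are illustrative assumptions, not the model's actual configuration.

```python
import torch

# Assumed toy values, for illustration only (LLaVA reads these from its config).
image_token_index = 32000
hidden_size = 4

input_ids = torch.tensor([[1, 32000, 32000, 7, 8]])   # (batch, seq_len)
inputs_embeds = torch.randn(1, 5, hidden_size)        # text embeddings
image_features = torch.randn(1, 2, hidden_size)       # projected vision features

# The k-th image placeholder in each row should receive image_features[:, k].
mask = input_ids == image_token_index                 # bool, (batch, seq_len)
indices1 = mask.to(torch.int64).cumsum(1) - 1         # -1 before the first placeholder
indices0 = torch.arange(mask.shape[0]).view(-1, 1)    # batch indices for advanced indexing
image_features_expanded = image_features[indices0, indices1]  # (batch, seq_len, hidden)

# Keep image embeddings at placeholder positions, text embeddings elsewhere.
merged = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

# A decode step (seq_len == 1) carries no image placeholders, so the diff selects the
# plain text embeddings with a tensor comparison instead of a Python `if`, keeping the
# traced graph free of data-dependent control flow.
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, merged)
print(inputs_embeds.shape)  # torch.Size([1, 5, 4])
```

The gather produces out-of-range indices (the leading -1) only at positions where `torch.where` discards the result, so no masked scatter or boolean indexing with a dynamic output shape is needed.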