Commit 673d30b

Chameleon: minor fixes after shipping (huggingface#32037)
* fix merging
* make chameleon conditional
1 parent 765732e commit 673d30b

File tree

7 files changed: +38 -31 lines changed

docs/source/en/model_doc/chameleon.md

+10-10
@@ -64,13 +64,13 @@ The original code can be found [here](https://github.com/facebookresearch/chamel
 Here's how to load the model and perform inference in half-precision (`torch.float16`):
 
 ```python
-from transformers import ChameleonProcessor, ChameleonForCausalLM
+from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
 import torch
 from PIL import Image
 import requests
 
 processor = ChameleonProcessor.from_pretrained("meta-chameleon")
-model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")
+model = ChameleonForConditionalGeneration.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")
 
 # prepare image and text prompt
 url = "https://bjiujitsu.com/wp-content/uploads/2021/01/jiu_jitsu_belt_white_1.jpg"
@@ -89,13 +89,13 @@ print(processor.decode(output[0], skip_special_tokens=True))
 Chameleon can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it:
 
 ```python
-from transformers import ChameleonProcessor, ChameleonForCausalLM
+from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
 import torch
 from PIL import Image
 import requests
 
 processor = ChameleonProcessor.from_pretrained("meta-chameleon")
-model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")
+model = ChameleonForConditionalGeneration.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")
 
 # Get three different images
 url = "https://www.ilankelman.org/stopsigns/australia.jpg"
@@ -129,7 +129,7 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza
 The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with:
 
 ```python
-from transformers import ChameleonForCausalLM, BitsAndBytesConfig
+from transformers import ChameleonForConditionalGeneration, BitsAndBytesConfig
 
 # specify how to quantize the model
 quantization_config = BitsAndBytesConfig(
@@ -138,17 +138,17 @@ quantization_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16,
 )
 
-model = ChameleonForCausalLM.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto")
+model = ChameleonForConditionalGeneration.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto")
 ```
 
 ### Use Flash-Attention 2 and SDPA to further speed-up generation
 
 The models supports both, Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) which can be enables for optimization. SDPA is the default options when you load the model, If you want to switch for Flash Attention 2, first make sure to install flash-attn. Refer to the [original repository](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with:
 
 ```python
-from transformers import ChameleonForCausalLM
+from transformers import ChameleonForConditionalGeneration
 
-model = ChameleonForCausalLM.from_pretrained(
+model = ChameleonForConditionalGeneration.from_pretrained(
     model_id,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
@@ -183,7 +183,7 @@ model = ChameleonForCausalLM.from_pretrained(
 [[autodoc]] ChameleonModel
     - forward
 
-## ChameleonForCausalLM
+## ChameleonForConditionalGeneration
 
-[[autodoc]] ChameleonForCausalLM
+[[autodoc]] ChameleonForConditionalGeneration
     - forward
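Note: the Flash-Attention 2 hunk above ends before the `attn_implementation` argument, so the snippet is truncated and `model_id` is never defined in the shown lines. Below is a minimal, self-contained sketch of what the full call looks like with the renamed class; `model_id = "facebook/chameleon-7b"` is an assumption, and flash-attn must be installed separately.

```python
# Hedged sketch completing the truncated Flash-Attention 2 snippet above.
# model_id is an assumption; the diff hunk leaves it undefined.
import torch
from transformers import ChameleonForConditionalGeneration, ChameleonProcessor

model_id = "facebook/chameleon-7b"
processor = ChameleonProcessor.from_pretrained(model_id)
model = ChameleonForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",  # opt in; SDPA stays the default otherwise
).to("cuda")
```

As the doc text itself says, SDPA is the default attention implementation, so this argument is only needed when switching to Flash-Attention 2.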

src/transformers/__init__.py

+2-2
@@ -1616,7 +1616,7 @@
     )
     _import_structure["models.chameleon"].extend(
         [
-            "ChameleonForCausalLM",
+            "ChameleonForConditionalGeneration",
             "ChameleonModel",
             "ChameleonPreTrainedModel",
             "ChameleonProcessor",
@@ -6276,7 +6276,7 @@
         load_tf_weights_in_canine,
     )
     from .models.chameleon import (
-        ChameleonForCausalLM,
+        ChameleonForConditionalGeneration,
         ChameleonModel,
         ChameleonPreTrainedModel,
         ChameleonProcessor,
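The same symbol is renamed twice in this file because the top-level `transformers` package is built lazily: `_import_structure` maps submodules to exported names as strings, while the plain import in the second hunk covers the type-checking path, and the two have to stay in sync. A simplified sketch of that lazy-import idea follows; it is not the actual `transformers` `_LazyModule` implementation, just the shape of it.

```python
# Simplified sketch of the lazy-import pattern behind _import_structure
# (illustrative only, not the real transformers._LazyModule).
import importlib
import types


class LazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # invert {submodule: [names]} into {name: submodule} for fast lookup
        self._name_to_submodule = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        # import the owning submodule only when the symbol is first accessed
        submodule = self._name_to_submodule[symbol]
        module = importlib.import_module(f".{submodule}", self.__name__)
        return getattr(module, symbol)


# After this commit, the Chameleon entry exports the renamed class:
_import_structure = {"models.chameleon": ["ChameleonForConditionalGeneration", "ChameleonModel"]}
```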

src/transformers/models/auto/modeling_auto.py

+1-1
@@ -446,7 +446,6 @@
         ("blenderbot-small", "BlenderbotSmallForCausalLM"),
         ("bloom", "BloomForCausalLM"),
         ("camembert", "CamembertForCausalLM"),
-        ("chameleon", "ChameleonForCausalLM"),
         ("code_llama", "LlamaForCausalLM"),
         ("codegen", "CodeGenForCausalLM"),
         ("cohere", "CohereForCausalLM"),
@@ -703,6 +702,7 @@
     [
         ("blip", "BlipForConditionalGeneration"),
         ("blip-2", "Blip2ForConditionalGeneration"),
+        ("chameleon", "ChameleonForConditionalGeneration"),
         ("git", "GitForCausalLM"),
         ("idefics2", "Idefics2ForConditionalGeneration"),
         ("instructblip", "InstructBlipForConditionalGeneration"),

src/transformers/models/chameleon/__init__.py

+2-2
@@ -36,7 +36,7 @@
     pass
 else:
     _import_structure["modeling_chameleon"] = [
-        "ChameleonForCausalLM",
+        "ChameleonForConditionalGeneration",
         "ChameleonModel",
         "ChameleonPreTrainedModel",
         "ChameleonVQVAE",
@@ -62,7 +62,7 @@
         pass
     else:
         from .modeling_chameleon import (
-            ChameleonForCausalLM,
+            ChameleonForConditionalGeneration,
            ChameleonModel,
            ChameleonPreTrainedModel,
            ChameleonVQVAE,

src/transformers/models/chameleon/modeling_chameleon.py

+5-4
@@ -1279,7 +1279,8 @@ def forward(
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.device, input_ids.dtype)
+            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
+            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1445,7 +1446,7 @@ def _update_causal_mask(
     "Chameleon Model with a head on top used for outputting logits for next token prediction.",
     CHAMELEON_START_DOCSTRING,
 )
-class ChameleonForCausalLM(ChameleonPreTrainedModel):
+class ChameleonForConditionalGeneration(ChameleonPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1504,12 +1505,12 @@ def forward(
         Example:
 
         ```python
-        >>> from transformers import ChameleonProcessor, ChameleonForCausalLM
+        >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
         >>> import torch
         >>> import requests
         >>> from PIL import Image
 
-        >>> model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
+        >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
         >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
 
         >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."

src/transformers/utils/dummy_pt_objects.py

+1-1
@@ -1835,7 +1835,7 @@ def load_tf_weights_in_canine(*args, **kwargs):
     requires_backends(load_tf_weights_in_canine, ["torch"])
 
 
-class ChameleonForCausalLM(metaclass=DummyObject):
+class ChameleonForConditionalGeneration(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
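The hunk ends at the `__init__` signature; in these generated dummy files the body conventionally just reports the missing backend. A simplified sketch of the dummy-object idea follows; it is not the actual transformers `DummyObject`/`requires_backends` code, only an illustration of why the class exists when torch is not installed.

```python
# Simplified sketch of the dummy-object pattern (illustrative only): the class
# can be imported without torch, but any use raises a clear error.
class DummyObject(type):
    """Metaclass whose classes complain about missing backends when used."""

    def __getattr__(cls, name):
        raise ImportError(f"{cls.__name__} requires the backends {cls._backends} to be installed.")


class ChameleonForConditionalGeneration(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        raise ImportError(f"{type(self).__name__} requires the backends {self._backends} to be installed.")
```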

tests/models/chameleon/test_modeling_chameleon.py

+17-11
@@ -44,7 +44,7 @@
     import torch
 
     from transformers import (
-        ChameleonForCausalLM,
+        ChameleonForConditionalGeneration,
         ChameleonModel,
         ChameleonProcessor,
     )
@@ -191,7 +191,7 @@ def create_and_check_for_causal_lm(
         encoder_hidden_states,
         encoder_attention_mask,
     ):
-        model = ChameleonForCausalLM(config=config)
+        model = ChameleonForConditionalGeneration(config=config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -209,7 +209,7 @@ def create_and_check_decoder_model_past_large_inputs(
         encoder_attention_mask,
     ):
         config.is_decoder = True
-        model = ChameleonForCausalLM(config=config)
+        model = ChameleonForConditionalGeneration(config=config)
         model.to(torch_device)
         model.eval()
 
@@ -273,12 +273,12 @@ def prepare_config_and_inputs_for_common(self):
 
 @require_torch
 class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (ChameleonModel, ChameleonForCausalLM) if is_torch_available() else ()
-    all_generative_model_classes = (ChameleonForCausalLM,) if is_torch_available() else ()
+    all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
+    all_generative_model_classes = (ChameleonForConditionalGeneration,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": ChameleonModel,
-            "text-generation": ChameleonForCausalLM,
+            "text-generation": ChameleonForConditionalGeneration,
         }
         if is_torch_available()
         else {}
@@ -339,7 +339,7 @@ def test_flash_attn_2_generate_padding_right(self):
         """
         Overwritting the common test as the test is flaky on tiny models
         """
-        model = ChameleonForCausalLM.from_pretrained(
+        model = ChameleonForConditionalGeneration.from_pretrained(
             "facebook/chameleon-7b",
             load_in_4bit=True,
             device_map={"": 0},
@@ -355,7 +355,7 @@ def test_flash_attn_2_generate_padding_right(self):
         output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
         output_native = processor.tokenizer.batch_decode(output_native)
 
-        model = ChameleonForCausalLM.from_pretrained(
+        model = ChameleonForConditionalGeneration.from_pretrained(
             "facebook/chameleon-7b",
             load_in_4bit=True,
             attn_implementation="flash_attention_2",
@@ -377,7 +377,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     @require_read_token
     def test_model_7b(self):
-        model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
+        model = ChameleonForConditionalGeneration.from_pretrained(
+            "facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
+        )
         processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
 
         image = Image.open(
@@ -397,7 +399,9 @@ def test_model_7b(self):
     @require_bitsandbytes
     @require_read_token
     def test_model_7b_batched(self):
-        model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
+        model = ChameleonForConditionalGeneration.from_pretrained(
+            "facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
+        )
         processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
 
         image = Image.open(
@@ -428,7 +432,9 @@ def test_model_7b_batched(self):
     @require_bitsandbytes
     @require_read_token
     def test_model_7b_multi_image(self):
-        model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
+        model = ChameleonForConditionalGeneration.from_pretrained(
+            "facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
+        )
         processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
 
         image = Image.open(
