Llama-3.2-11B-Vision-Instruct vocab size vs. lm head mismatch #33819

@harshil-shah

Description

System Info

  • transformers version: 4.45.1
  • Platform: Linux-6.8.0-1013-gcp-x86_64-with-glibc2.39
  • Python version: 3.11.9
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.2
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: FSDP
    • mixed_precision: bf16
    • use_cpu: False
    • debug: False
    • num_processes: 8
    • machine_rank: 0
    • num_machines: 1
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: False
    • fsdp_config: {'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': False, 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 'HYBRID_SHARD', 'fsdp_state_dict_type': 'FULL_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_transformer_layer_cls_to_wrap': 'MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer', 'fsdp_use_orig_params': True}
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • PyTorch version (GPU?): 2.2.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi,

It seems there is a mismatch between the vocab size of the MllamaProcessor's tokenizer and the size of the lm_head weight matrix, and calling resize_token_embeddings doesn't fix it. As a result, training fails whenever the labels contain the extra token id. Minimal example:

import requests
from PIL import Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"

processor = MllamaProcessor.from_pretrained(MODEL_NAME)
model = MllamaForConditionalGeneration.from_pretrained(MODEL_NAME)

print(f"{len(processor.tokenizer) = }")
print(f"Before resize: {model.language_model.lm_head.weight.shape = }")

model.resize_token_embeddings(len(processor.tokenizer))

print(f"After resize: {model.language_model.lm_head.weight.shape = }")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
image = Image.open(requests.get(url, stream=True).raw)

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}
    ]}
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

output = model(**inputs, labels=inputs.input_ids)

This outputs:

len(processor.tokenizer) = 128257
Before resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])
After resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])
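
The extra id presumably corresponds to the special image placeholder token added by the tokenizer; this can be checked directly (the expected output below is an assumption based on the sizes above):

print(processor.tokenizer.convert_ids_to_tokens([len(processor.tokenizer) - 1]))
# Expected (assumption): ['<|image|>']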

And then errors with:

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:2188, in MllamaForConditionalGeneration.forward(self, input_ids, pixel_values, aspect_ratio_mask, aspect_ratio_ids, attention_mask, cross_attention_mask, cross_attention_states, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   2185     cross_attention_mask = cross_attention_mask[:, :, cache_position]
   2186     full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-> 2188 outputs = self.language_model(
   2189     input_ids=input_ids,
   2190     attention_mask=attention_mask,
   2191     position_ids=position_ids,
   2192     cross_attention_states=cross_attention_states,
   2193     cross_attention_mask=cross_attention_mask,
   2194     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
   2195     past_key_values=past_key_values,
   2196     use_cache=use_cache,
   2197     inputs_embeds=inputs_embeds,
   2198     labels=labels,
   2199     output_hidden_states=output_hidden_states,
   2200     output_attentions=output_attentions,
   2201     return_dict=return_dict,
   2202     cache_position=cache_position,
   2203     num_logits_to_keep=num_logits_to_keep,
   2204 )
   2206 return outputs

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:1961, in MllamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   1959     # Enable model parallelism
   1960     shift_labels = shift_labels.to(shift_logits.device)
-> 1961     loss = loss_fct(shift_logits, shift_labels)
   1963 if not return_dict:
   1964     output = (logits,) + outputs[1:]

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py:1179, in CrossEntropyLoss.forward(self, input, target)
   1178 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1179     return F.cross_entropy(input, target, weight=self.weight,
   1180                            ignore_index=self.ignore_index, reduction=self.reduction,
   1181                            label_smoothing=self.label_smoothing)

File ~/.venv/lib/python3.11/site-packages/torch/nn/functional.py:3059, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3057 if size_average is not None or reduce is not None:
   3058     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3059 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Target 128256 is out of bounds.
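
The IndexError itself is just PyTorch's cross-entropy rejecting a target id that is greater than or equal to the number of classes produced by lm_head. A minimal standalone sketch with the same sizes (not Mllama-specific) reproduces it:

import torch
import torch.nn.functional as F

vocab_size = 128256                  # matches lm_head.weight.shape[0]
logits = torch.randn(1, vocab_size)  # valid class ids are 0..128255
labels = torch.tensor([128256])      # the extra tokenizer id is out of range

F.cross_entropy(logits, labels)      # raises IndexError: Target 128256 is out of bounds.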

Expected behavior

The vocab size of the processor's tokenizer and the model's lm_head should match, or resize_token_embeddings should be able to bring them in line, so that a training loss can be computed on inputs produced by the processor.
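
For what it's worth, a possible workaround for fine-tuning (assuming the image placeholder token is spelled "<|image|>" and is not meant to be predicted) is to mask it out of the labels instead of resizing the model, so the loss ignores those positions:

labels = inputs.input_ids.clone()
# Assumption: the image placeholder token should not contribute to the loss.
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image|>")
labels[labels == image_token_id] = -100  # -100 is ignored by CrossEntropyLoss

output = model(**inputs, labels=labels)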
