
Got shape error when running JetMoe #40817

@wtomin

System Info

Environment

  • Linux, A100 GPU (CUDA 12.1, driver version 530.30.02)
  • python=3.10.0
  • transformers=4.50.0 (upgrade to 4.56.1 leads to the same error)

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run this code snippet with transformers and jetmoe/jetmoe-8b (I found this example in #40749):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
model = AutoModelForCausalLM.from_pretrained(
    "jetmoe/jetmoe-8b",
    device_map="auto",
    attn_implementation="sdpa"
)
input_ids = tokenizer("The stock market rallied today after positive economic news", return_tensors="pt").to(model.device)

output = model.generate(**input_ids, cache_implementation=None)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Expected behavior

The snippet should generate text without raising. Instead, it fails with this error message:

Traceback (most recent call last):
  File "/data2/g00523483/ddd/JetMoE/run_model.py", line 12, in <module>
    output = model.generate(**input_ids, cache_implementation=None)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/generation/utils.py", line 2326, in generate
    result = self._sample(
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/generation/utils.py", line 3286, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 1336, in forward
    outputs = self.model(
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 1083, in forward
    layer_outputs = decoder_layer(
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 831, in forward
    attn_output, self_attn_weights, present_key_value, attn_router_logits = self.self_attention(
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 637, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  File "/data2/g00523483/.conda/envs/ddd/lib/python3.10/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 489, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 3
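
For readers less familiar with this class of error: the message comes from the elementwise multiply `q * cos`, which only works if the two shapes are broadcast-compatible (aligned from the right, each pair of sizes must be equal or one of them must be 1). Below is a minimal pure-Python sketch of that rule. The helper name `broadcast_shapes` and the example shapes are illustrative assumptions; only the last-dimension sizes 128 and 64 and the dimension index 3 come from the traceback, and the issue does not confirm why `cos` ends up with a last dimension of 64.

```python
def broadcast_shapes(a, b):
    """Return the broadcast shape of a and b, or raise ValueError on mismatch.

    Hypothetical helper mirroring the usual NumPy/PyTorch rule: align the
    shapes from the right; two sizes are compatible if they are equal or
    if one of them is 1.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    # Walk from the last dimension to the first.
    for dim in range(ndim - 1, -1, -1):
        x, y = a[dim], b[dim]
        if x != y and x != 1 and y != 1:
            raise ValueError(
                f"size of tensor a ({x}) must match size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
        # A size of 1 stretches to match the other operand.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


# Assumed layout (batch, heads, seq_len, head_dim); only 128 and 64 are
# taken from the error message, the other sizes are placeholders.
q_shape = (1, 32, 10, 128)
cos_shape = (1, 1, 10, 64)

try:
    broadcast_shapes(q_shape, cos_shape)
except ValueError as err:
    print(err)
    # size of tensor a (128) must match size of tensor b (64) at non-singleton dimension 3
```

With a matching last dimension, for example `(1, 1, 10, 128)` for `cos`, the same check passes and returns `(1, 32, 10, 128)`. This suggests the rotary embedding tables are built for a different head dimension than the one used by the query and key projections, but that is a guess until someone inspects the model config.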
