@@ -1483,8 +1483,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

>>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm/PowerMoE-3b")
>>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
>>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny")
>>> tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -300,6 +300,31 @@ def __init__(self, config: GraniteMoeHybridConfig):
# Initialize weights and apply final processing
self.post_init()

def forward(self, **super_kwargs):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

Example:

```python
>>> from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

>>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny")
>>> tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
return super().forward(**super_kwargs)

def prepare_inputs_for_generation(
self,
input_ids,
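The docstring in the hunk above documents the `labels` argument of `forward`; a minimal sketch of exercising it (the checkpoint name comes from this diff, the rest is assumed and not part of the change):

```python
from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny")

inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")

# Reuse the input ids as labels so forward() returns the causal LM loss;
# any position set to -100 would be ignored when computing that loss.
outputs = model(**inputs, labels=inputs.input_ids)
print(outputs.loss)
```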
2 changes: 1 addition & 1 deletion src/transformers/testing_utils.py
@@ -3384,7 +3384,7 @@ def _get_test_info():
for frame_info in reversed(captured_frames):
tb = types.TracebackType(tb_next, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno)
tb_next = tb
test_traceback = tb
test_traceback = tb_next

origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"]

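For context on the `testing_utils.py` change above: `types.TracebackType` nodes are chained through their first constructor argument, so the node assigned last in such a loop ends up as the head of the synthesized traceback. A self-contained sketch of the same technique (the helper name and frame-capture approach are illustrative, not from the repository):

```python
import inspect
import traceback
import types

def build_synthetic_traceback():
    """Chain TracebackType nodes so the returned head walks from caller to callee."""
    tb_next = None
    # inspect.stack() lists the innermost frame first; building the chain in that
    # order leaves the outermost frame as the head node returned below.
    for frame_info in inspect.stack()[1:]:
        tb_next = types.TracebackType(
            tb_next, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno
        )
    return tb_next

def demo():
    synthetic_tb = build_synthetic_traceback()
    print("".join(traceback.format_tb(synthetic_tb)))

demo()
```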
33 changes: 16 additions & 17 deletions tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py
@@ -19,6 +19,7 @@
import unittest

import pytest
from parameterized import parameterized
from pytest import mark

from transformers import (
@@ -354,50 +355,48 @@ def test_config_requires_mamba_or_attention_layers(self):
GraniteMoeHybridConfig(layer_types=["not allowed!"])


# TODO (@alex-jw-brooks) - update this once the model(s) are out
@unittest.skip(reason="GraniteMoeHybrid models are not yet released")
@require_torch_accelerator
class GraniteMoeHybridIntegrationTest(unittest.TestCase):
@slow
def test_model_logits(self):
@parameterized.expand([("cpu",)]) # runners crash with `cuda`, prob they have mamba kernels installed
def test_model_logits(self, device):
input_ids = [31390, 631, 4162, 30, 322, 25342, 432, 1875, 43826, 10066, 688, 225]

model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny", device_map=device)

with torch.no_grad():
out = model(torch.tensor([input_ids]).to(torch_device))
out = model(torch.tensor([input_ids]).to(device))

# fmt: off
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([
[-2.9711, -2.2554, -1.0814, -1.6123, -0.8780, -1.0685, -0.6368, -1.9732, -3.3548, -2.6895, -2.3062, -2.6338]
])
[-0.3543, -1.0256, -0.5118, -0.8711, -0.6722, 0.0736, -1.3630, -0.1100, -1.8382, -1.6288, -1.5097, -0.5010]
], device=device)

torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(EXPECTED_MEAN, out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)

# slicing logits[0, 0, 0:15]
EXPECTED_SLICE = torch.tensor([
[4.0662, 5.9547, 3.5803, 3.1306, 4.3211, 3.8902, 4.6438, 8.5434, 7.5865, 5.1623, 5.2240, 9.2982, 5.9094, 6.8834, 5.7551],
])
[6.5938, 7.2500, 1.6484, 5.2188, 3.5781, 2.5469, 6.1250, 5.1875, 9.5000, 4.6875, 4.7188, 10.7500, 10.3125, 7.8438, 5.5312],
], device=device)
# fmt: on

self.assertTrue(
torch.allclose(
EXPECTED_SLICE.to(torch_device),
EXPECTED_SLICE,
out.logits[0, 0, :15].float(),
atol=1e-3,
rtol=1e-3,
)
)

@slow
def test_model_generation(self):
EXPECTED_TEXT_COMPLETION = (
"Simply put, the theory of relativity states that 1) time is relative, and 2) space is relative. The first"
)
@parameterized.expand([("cpu",)])
def test_model_generation(self, device):
EXPECTED_TEXT_COMPLETION = "Simply put, the theory of relativity states that 1) the laws of physics are the same in all inertial reference frames,"
prompt = "Simply put, the theory of relativity states that "
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-tiny")
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny", device_map=device)
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# greedy generation outputs
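The remainder of the generation test is collapsed in this view; a typical greedy-decoding check of the same shape looks roughly like the sketch below (settings such as `max_new_tokens` are assumed, and the comparison is not the PR's actual assertion code):

```python
from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny", device_map="cpu")

prompt = "Simply put, the theory of relativity states that "
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding keeps the completion deterministic so it can be compared to a fixed string.
generated_ids = model.generate(**model_inputs, max_new_tokens=20, do_sample=False)
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)
```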