Skip to content

Commit 251825a

Browse files
SunMarcArthurZucker
authored andcommitted
fix dict like init for ModelOutput (#41002)
* fix dict like init * style
1 parent 6e1270d commit 251825a

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/transformers/utils/generic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ def __post_init__(self):
380380
# if we provided an iterator as first field and the iterator is a (key, value) iterator
381381
# set the associated fields
382382
if first_field_iterator:
383+
# reset first field to None
384+
setattr(self, class_fields[0].name, None)
383385
for idx, element in enumerate(iterator):
384386
if not isinstance(element, (list, tuple)) or len(element) != 2 or not isinstance(element[0], str):
385387
if idx == 0:

tests/utils/test_generic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020

2121
from transformers.configuration_utils import PretrainedConfig
22-
from transformers.modeling_outputs import BaseModelOutput
22+
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
2323
from transformers.testing_utils import require_torch
2424
from transformers.utils import (
2525
can_return_tuple,
@@ -139,6 +139,19 @@ def test_to_py_obj_torch(self):
139139

140140
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
141141

142+
def test_model_output_subclass(self):
143+
# testing with “dict-like init” case
144+
out = CausalLMOutputWithPast({"logits": torch.ones(2, 3, 4)})
145+
self.assertTrue(out["logits"] is not None)
146+
self.assertTrue(out.loss is None)
147+
self.assertTrue(len(out.to_tuple()) == 1)
148+
149+
# testing with dataclass init case
150+
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
151+
self.assertTrue(out["logits"] is not None)
152+
self.assertTrue(out.loss is None)
153+
self.assertTrue(len(out.to_tuple()) == 1)
154+
142155

143156
class ValidationDecoratorTester(unittest.TestCase):
144157
def test_cases_no_warning(self):

0 commit comments

Comments
 (0)