Skip to content

Commit 9bb2c01

Browse files
committed
clean
1 parent ad3d69b commit 9bb2c01

File tree

2 files changed

+133
-19
lines changed

2 files changed

+133
-19
lines changed

src/oumi/datasets/vision_language/coco_captions.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def transform_conversation(self, example: dict) -> Conversation:
1616
if required_key not in example:
1717
raise ValueError(
1818
f"Example doesn't contain '{required_key}' key. "
19-
f"Available keys: {example.keys()}"
19+
f"Available keys: {example.keys()}."
2020
)
2121

2222
if "raw" not in example["sentences"]:
2323
raise ValueError(
24-
f"Example doesn't contain 'sentences.raw' key. "
25-
f"Available keys under 'sentences.': {example['sentences'].keys()}"
24+
f"Training example doesn't contain 'sentences.raw' key. "
25+
f"Available keys under 'sentences.': {example['sentences'].keys()}."
2626
)
2727
output_text = example["sentences"]["raw"]
2828

@@ -46,8 +46,9 @@ def transform_conversation(self, example: dict) -> Conversation:
4646
)
4747
else:
4848
raise ValueError(
49-
f"Example contains none of required keys: 'image.bytes', 'image.path'. "
50-
f"Available keys under 'image.': {example['image'].keys()}"
49+
"Training example contains none of required keys: "
50+
"'image.bytes', 'image.path'. "
51+
f"Available keys under 'image.': {example['image'].keys()}."
5152
)
5253

5354
messages.append(Message(role=Role.ASSISTANT, content=output_text))

tests/core/datasets/test_vision_language_dataset.py

+127-14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import functools
2+
import io
3+
from typing import Optional, Tuple
14
from unittest.mock import Mock, patch
25

36
import pytest
@@ -8,6 +11,17 @@
811
from oumi.core.types.turn import Conversation, Message, Role, Type
912

1013

14+
class EqBytesIO:
15+
def __init__(self, bytes_io: io.BytesIO):
16+
self._byte_io = bytes_io
17+
18+
def __eq__(self, other):
19+
return (
20+
isinstance(other, io.BytesIO)
21+
and other.getvalue() == self._byte_io.getvalue()
22+
)
23+
24+
1125
@pytest.fixture
1226
def mock_processor():
1327
processor = Mock()
@@ -22,8 +36,18 @@ def mock_processor():
2236
return processor
2337

2438

39+
@functools.lru_cache(maxsize=None) # same as @cache added in Python 3.9
40+
def _create_test_image(image_size: Optional[Tuple[int, int]] = None) -> bytes:
41+
if image_size is None:
42+
image_size = (80, 40)
43+
image = Image.new(mode="RGBA", size=image_size)
44+
bytes_io = io.BytesIO()
45+
image.save(bytes_io, format="PNG")
46+
return bytes_io.getvalue()
47+
48+
2549
@pytest.fixture
26-
def sample_conversation():
50+
def sample_conversation_using_image_path():
2751
return Conversation(
2852
messages=[
2953
Message(role=Role.USER, content="Describe this image:", type=Type.TEXT),
@@ -38,36 +62,101 @@ def sample_conversation():
3862

3963

4064
@pytest.fixture
41-
def test_dataset(mock_processor, sample_conversation):
42-
class TestDataset(VisionLanguageSftDataset):
65+
def sample_conversation_using_image_binary():
66+
return Conversation(
67+
messages=[
68+
Message(role=Role.USER, content="Describe this image:", type=Type.TEXT),
69+
Message(
70+
role=Role.USER, binary=_create_test_image(), type=Type.IMAGE_BINARY
71+
),
72+
Message(
73+
role=Role.ASSISTANT,
74+
content="A beautiful sunset over the ocean.",
75+
type=Type.TEXT,
76+
),
77+
]
78+
)
79+
80+
81+
@pytest.fixture
82+
def test_dataset_using_image_path(
83+
mock_processor: Mock, sample_conversation_using_image_path: Conversation
84+
):
85+
class TestDatasetImagePath(VisionLanguageSftDataset):
86+
default_dataset = "custom"
87+
88+
def transform_conversation(self, example):
89+
return sample_conversation_using_image_path
90+
91+
def _load_data(self):
92+
pass
93+
94+
return TestDatasetImagePath(processor=mock_processor)
95+
96+
97+
@pytest.fixture
98+
def test_dataset_using_image_binary(
99+
mock_processor: Mock, sample_conversation_using_image_binary: Conversation
100+
):
101+
class TestDatasetImageBinary(VisionLanguageSftDataset):
43102
default_dataset = "custom"
44103

45104
def transform_conversation(self, example):
46-
return sample_conversation
105+
return sample_conversation_using_image_binary
47106

48107
def _load_data(self):
49108
pass
50109

51-
return TestDataset(processor=mock_processor)
110+
return TestDatasetImageBinary(processor=mock_processor)
52111

53112

54-
def test_transform_image(test_dataset):
113+
def test_transform_image_using_image_path(test_dataset_using_image_path):
55114
with patch("PIL.Image.open") as mock_open:
56115
mock_image = Mock(spec=Image.Image)
57116
mock_open.return_value.convert.return_value = mock_image
58117

59-
test_dataset.transform_image("path/to/image.jpg")
118+
test_dataset_using_image_path.transform_image("path/to/image.jpg")
60119

61120
mock_open.assert_called_once_with("path/to/image.jpg")
62-
test_dataset._image_processor.assert_called_once()
121+
test_dataset_using_image_path._image_processor.assert_called_once()
63122

64123

65-
def test_transform_simple_model(test_dataset):
66-
with patch.object(test_dataset, "_load_image") as mock_load_image:
124+
def test_transform_image_using_image_binary(test_dataset_using_image_binary):
125+
with patch("PIL.Image.open") as mock_open:
126+
mock_image = Mock(spec=Image.Image)
127+
mock_open.return_value.convert.return_value = mock_image
128+
129+
test_image_bytes = _create_test_image()
130+
test_dataset_using_image_binary.transform_image(
131+
Message(type=Type.IMAGE_BINARY, binary=test_image_bytes, role=Role.USER)
132+
)
133+
134+
mock_open.assert_called_once_with(EqBytesIO(io.BytesIO(test_image_bytes)))
135+
test_dataset_using_image_binary._image_processor.assert_called_once()
136+
137+
138+
def test_transform_simple_model_using_image_path(test_dataset_using_image_path):
139+
with patch.object(test_dataset_using_image_path, "_load_image") as mock_load_image:
140+
mock_image = Mock(spec=Image.Image)
141+
mock_load_image.return_value = mock_image
142+
143+
result = test_dataset_using_image_path.transform({"example": "data"})
144+
145+
assert isinstance(result, dict)
146+
assert "input_ids" in result
147+
assert "attention_mask" in result
148+
assert "labels" in result
149+
assert "pixel_values" in result
150+
151+
152+
def test_transform_simple_model_using_image_binary(test_dataset_using_image_binary):
153+
with patch.object(
154+
test_dataset_using_image_binary, "_load_image"
155+
) as mock_load_image:
67156
mock_image = Mock(spec=Image.Image)
68157
mock_load_image.return_value = mock_image
69158

70-
result = test_dataset.transform({"example": "data"})
159+
result = test_dataset_using_image_binary.transform({"example": "data"})
71160

72161
assert isinstance(result, dict)
73162
assert "input_ids" in result
@@ -76,15 +165,39 @@ def test_transform_simple_model(test_dataset):
76165
assert "pixel_values" in result
77166

78167

79-
def test_transform_instruct_model(test_dataset, mock_processor):
168+
def test_transform_instruct_model_using_image_path(
169+
test_dataset_using_image_path, mock_processor: Mock
170+
):
171+
mock_processor.chat_template = "Template"
172+
mock_processor.apply_chat_template = Mock(return_value="Processed template")
173+
174+
with patch.object(test_dataset_using_image_path, "_load_image") as mock_load_image:
175+
mock_image = Mock(spec=Image.Image)
176+
mock_load_image.return_value = mock_image
177+
178+
result = test_dataset_using_image_path.transform({"example": "data"})
179+
180+
assert isinstance(result, dict)
181+
assert "input_ids" in result
182+
assert "attention_mask" in result
183+
assert "labels" in result
184+
assert "pixel_values" in result
185+
mock_processor.apply_chat_template.assert_called_once()
186+
187+
188+
def test_transform_instruct_model_using_image_binary(
189+
test_dataset_using_image_binary, mock_processor: Mock
190+
):
80191
mock_processor.chat_template = "Template"
81192
mock_processor.apply_chat_template = Mock(return_value="Processed template")
82193

83-
with patch.object(test_dataset, "_load_image") as mock_load_image:
194+
with patch.object(
195+
test_dataset_using_image_binary, "_load_image"
196+
) as mock_load_image:
84197
mock_image = Mock(spec=Image.Image)
85198
mock_load_image.return_value = mock_image
86199

87-
result = test_dataset.transform({"example": "data"})
200+
result = test_dataset_using_image_binary.transform({"example": "data"})
88201

89202
assert isinstance(result, dict)
90203
assert "input_ids" in result

0 commit comments

Comments
 (0)