1
+ import functools
2
+ import io
3
+ from typing import Optional , Tuple
1
4
from unittest .mock import Mock , patch
2
5
3
6
import pytest
8
11
from oumi .core .types .turn import Conversation , Message , Role , Type
9
12
10
13
14
+ class EqBytesIO :
15
+ def __init__ (self , bytes_io : io .BytesIO ):
16
+ self ._byte_io = bytes_io
17
+
18
+ def __eq__ (self , other ):
19
+ return (
20
+ isinstance (other , io .BytesIO )
21
+ and other .getvalue () == self ._byte_io .getvalue ()
22
+ )
23
+
24
+
11
25
@pytest .fixture
12
26
def mock_processor ():
13
27
processor = Mock ()
@@ -22,8 +36,18 @@ def mock_processor():
22
36
return processor
23
37
24
38
39
+ @functools .lru_cache (maxsize = None ) # same as @cache added in Python 3.9
40
+ def _create_test_image (image_size : Optional [Tuple [int , int ]] = None ) -> bytes :
41
+ if image_size is None :
42
+ image_size = (80 , 40 )
43
+ image = Image .new (mode = "RGBA" , size = image_size )
44
+ bytes_io = io .BytesIO ()
45
+ image .save (bytes_io , format = "PNG" )
46
+ return bytes_io .getvalue ()
47
+
48
+
25
49
@pytest .fixture
26
- def sample_conversation ():
50
+ def sample_conversation_using_image_path ():
27
51
return Conversation (
28
52
messages = [
29
53
Message (role = Role .USER , content = "Describe this image:" , type = Type .TEXT ),
@@ -38,36 +62,101 @@ def sample_conversation():
38
62
39
63
40
64
@pytest .fixture
41
- def test_dataset (mock_processor , sample_conversation ):
42
- class TestDataset (VisionLanguageSftDataset ):
65
+ def sample_conversation_using_image_binary ():
66
+ return Conversation (
67
+ messages = [
68
+ Message (role = Role .USER , content = "Describe this image:" , type = Type .TEXT ),
69
+ Message (
70
+ role = Role .USER , binary = _create_test_image (), type = Type .IMAGE_BINARY
71
+ ),
72
+ Message (
73
+ role = Role .ASSISTANT ,
74
+ content = "A beautiful sunset over the ocean." ,
75
+ type = Type .TEXT ,
76
+ ),
77
+ ]
78
+ )
79
+
80
+
81
+ @pytest .fixture
82
+ def test_dataset_using_image_path (
83
+ mock_processor : Mock , sample_conversation_using_image_path : Conversation
84
+ ):
85
+ class TestDatasetImagePath (VisionLanguageSftDataset ):
86
+ default_dataset = "custom"
87
+
88
+ def transform_conversation (self , example ):
89
+ return sample_conversation_using_image_path
90
+
91
+ def _load_data (self ):
92
+ pass
93
+
94
+ return TestDatasetImagePath (processor = mock_processor )
95
+
96
+
97
+ @pytest .fixture
98
+ def test_dataset_using_image_binary (
99
+ mock_processor : Mock , sample_conversation_using_image_binary : Conversation
100
+ ):
101
+ class TestDatasetImageBinary (VisionLanguageSftDataset ):
43
102
default_dataset = "custom"
44
103
45
104
def transform_conversation (self , example ):
46
- return sample_conversation
105
+ return sample_conversation_using_image_binary
47
106
48
107
def _load_data (self ):
49
108
pass
50
109
51
- return TestDataset (processor = mock_processor )
110
+ return TestDatasetImageBinary (processor = mock_processor )
52
111
53
112
54
- def test_transform_image ( test_dataset ):
113
+ def test_transform_image_using_image_path ( test_dataset_using_image_path ):
55
114
with patch ("PIL.Image.open" ) as mock_open :
56
115
mock_image = Mock (spec = Image .Image )
57
116
mock_open .return_value .convert .return_value = mock_image
58
117
59
- test_dataset .transform_image ("path/to/image.jpg" )
118
+ test_dataset_using_image_path .transform_image ("path/to/image.jpg" )
60
119
61
120
mock_open .assert_called_once_with ("path/to/image.jpg" )
62
- test_dataset ._image_processor .assert_called_once ()
121
+ test_dataset_using_image_path ._image_processor .assert_called_once ()
63
122
64
123
65
- def test_transform_simple_model (test_dataset ):
66
- with patch .object (test_dataset , "_load_image" ) as mock_load_image :
124
+ def test_transform_image_using_image_binary (test_dataset_using_image_binary ):
125
+ with patch ("PIL.Image.open" ) as mock_open :
126
+ mock_image = Mock (spec = Image .Image )
127
+ mock_open .return_value .convert .return_value = mock_image
128
+
129
+ test_image_bytes = _create_test_image ()
130
+ test_dataset_using_image_binary .transform_image (
131
+ Message (type = Type .IMAGE_BINARY , binary = test_image_bytes , role = Role .USER )
132
+ )
133
+
134
+ mock_open .assert_called_once_with (EqBytesIO (io .BytesIO (test_image_bytes )))
135
+ test_dataset_using_image_binary ._image_processor .assert_called_once ()
136
+
137
+
138
+ def test_transform_simple_model_using_image_path (test_dataset_using_image_path ):
139
+ with patch .object (test_dataset_using_image_path , "_load_image" ) as mock_load_image :
140
+ mock_image = Mock (spec = Image .Image )
141
+ mock_load_image .return_value = mock_image
142
+
143
+ result = test_dataset_using_image_path .transform ({"example" : "data" })
144
+
145
+ assert isinstance (result , dict )
146
+ assert "input_ids" in result
147
+ assert "attention_mask" in result
148
+ assert "labels" in result
149
+ assert "pixel_values" in result
150
+
151
+
152
+ def test_transform_simple_model_using_image_binary (test_dataset_using_image_binary ):
153
+ with patch .object (
154
+ test_dataset_using_image_binary , "_load_image"
155
+ ) as mock_load_image :
67
156
mock_image = Mock (spec = Image .Image )
68
157
mock_load_image .return_value = mock_image
69
158
70
- result = test_dataset .transform ({"example" : "data" })
159
+ result = test_dataset_using_image_binary .transform ({"example" : "data" })
71
160
72
161
assert isinstance (result , dict )
73
162
assert "input_ids" in result
@@ -76,15 +165,39 @@ def test_transform_simple_model(test_dataset):
76
165
assert "pixel_values" in result
77
166
78
167
79
- def test_transform_instruct_model (test_dataset , mock_processor ):
168
+ def test_transform_instruct_model_using_image_path (
169
+ test_dataset_using_image_path , mock_processor : Mock
170
+ ):
171
+ mock_processor .chat_template = "Template"
172
+ mock_processor .apply_chat_template = Mock (return_value = "Processed template" )
173
+
174
+ with patch .object (test_dataset_using_image_path , "_load_image" ) as mock_load_image :
175
+ mock_image = Mock (spec = Image .Image )
176
+ mock_load_image .return_value = mock_image
177
+
178
+ result = test_dataset_using_image_path .transform ({"example" : "data" })
179
+
180
+ assert isinstance (result , dict )
181
+ assert "input_ids" in result
182
+ assert "attention_mask" in result
183
+ assert "labels" in result
184
+ assert "pixel_values" in result
185
+ mock_processor .apply_chat_template .assert_called_once ()
186
+
187
+
188
+ def test_transform_instruct_model_using_image_binary (
189
+ test_dataset_using_image_binary , mock_processor : Mock
190
+ ):
80
191
mock_processor .chat_template = "Template"
81
192
mock_processor .apply_chat_template = Mock (return_value = "Processed template" )
82
193
83
- with patch .object (test_dataset , "_load_image" ) as mock_load_image :
194
+ with patch .object (
195
+ test_dataset_using_image_binary , "_load_image"
196
+ ) as mock_load_image :
84
197
mock_image = Mock (spec = Image .Image )
85
198
mock_load_image .return_value = mock_image
86
199
87
- result = test_dataset .transform ({"example" : "data" })
200
+ result = test_dataset_using_image_binary .transform ({"example" : "data" })
88
201
89
202
assert isinstance (result , dict )
90
203
assert "input_ids" in result
0 commit comments