Skip to content

Commit 9ee4466

Browse files
author
Shogo Ujiie
committed
feat: add validator for dumped object type check
1 parent efe7cfc commit 9ee4466

File tree

3 files changed

+32
-18
lines changed

3 files changed

+32
-18
lines changed

gokart/target.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,11 @@
55
from datetime import datetime
66
from glob import glob
77
from logging import getLogger
8-
from typing import Any, Optional
8+
from typing import Any, Callable, Optional
99

1010
import luigi
1111
import numpy as np
1212
import pandas as pd
13-
import pandera as pa
1413

1514
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params
1615
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
@@ -79,11 +78,12 @@ def __init__(
7978
target: luigi.target.FileSystemTarget,
8079
processor: FileProcessor,
8180
task_lock_params: TaskLockParams,
82-
expected_dataframe_type: Optional[pa.DataFrameModel] = None,
81+
validator: Callable[[Any], bool] = lambda x: True,
8382
) -> None:
8483
self._target = target
8584
self._processor = processor
8685
self._task_lock_params = task_lock_params
86+
self._validator = validator
8787

8888
def _exists(self) -> bool:
8989
return self._target.exists()
@@ -94,14 +94,12 @@ def _get_task_lock_params(self) -> TaskLockParams:
9494
def _load(self) -> Any:
9595
with self._target.open('r') as f:
9696
obj = self._processor.load(f)
97-
if self.expected_dataframe_type is not None:
98-
return self.expected_dataframe_type(obj)
97+
self._validator(obj)
9998

10099
return obj
101100

102101
def _dump(self, obj) -> None:
103-
if self.expected_dataframe_type is not None:
104-
self.expected_dataframe_type.validate(obj)
102+
self._validator(obj)
105103

106104
with self._target.open('w') as f:
107105
self._processor.dump(obj, f)
@@ -225,12 +223,13 @@ def make_target(
225223
processor: Optional[FileProcessor] = None,
226224
task_lock_params: Optional[TaskLockParams] = None,
227225
store_index_in_feather: bool = True,
226+
validator: Callable[[Any], bool] = lambda x: True,
228227
) -> TargetOnKart:
229228
_task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
230229
file_path = _make_file_path(file_path, unique_id)
231230
processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather)
232231
file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather)
233-
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params)
232+
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validator)
234233

235234

236235
def make_model_target(

gokart/task.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,13 @@ def clone(self, cls=None, **kwargs):
193193

194194
return cls(**new_k)
195195

196-
def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart:
196+
def make_target(
197+
self,
198+
relative_file_path: Optional[str] = None,
199+
use_unique_id: bool = True,
200+
processor: Optional[FileProcessor] = None,
201+
validator: Callable[[Any], bool] = lambda x: True,
202+
) -> TargetOnKart:
197203
formatted_relative_file_path = (
198204
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl')
199205
)
@@ -210,7 +216,12 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
210216
)
211217

212218
return gokart.target.make_target(
213-
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
219+
file_path=file_path,
220+
unique_id=unique_id,
221+
processor=processor,
222+
task_lock_params=task_lock_params,
223+
store_index_in_feather=self.store_index_in_feather,
224+
validator=validator,
214225
)
215226

216227
def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:

test/test_target.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
import boto3
1010
import luigi
1111
import numpy as np
12-
import pandera as pa
1312
import pandas as pd
13+
import pandera as pa
1414
from matplotlib import pyplot
1515
from moto import mock_aws
1616

@@ -284,36 +284,40 @@ def test_model_target_on_s3(self):
284284

285285

286286
class SingleFileTargetTest(unittest.TestCase):
287-
class DummyDataFrameSchema(pa.DataFrameModel):
288-
a: pa.typing.Series[int] = pa.Field()
289-
290287
def test_typed_target(self):
288+
def validate_dataframe(x):
289+
return isinstance(x, pd.DataFrame)
290+
291291
test_case = pd.DataFrame(dict(a=[1, 2]))
292292

293293
with tempfile.TemporaryDirectory() as temp_dir:
294294
_task_lock_params = None
295-
file_path = os.path.join(temp_dir, 'test.csv')
295+
file_path = os.path.join(temp_dir, 'test.pkl')
296296
processor = make_file_processor(file_path, store_index_in_feather=False)
297297
file_system_target = luigi.LocalTarget(file_path, format=processor.format())
298-
file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=self.DummyDataFrameSchema)
298+
file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validate_dataframe)
299299

300300
file_target.dump(test_case)
301301
dumped_data = file_target.load()
302302
self.assertIsInstance(dumped_data, self.DummyDataFrameSchema)
303303

304304
def test_invalid_typed_target(self):
305+
def validate_int(x):
306+
return isinstance(x, int)
307+
305308
test_case = pd.DataFrame(dict(a=['1', '2']))
306309

307310
with tempfile.TemporaryDirectory() as temp_dir:
308311
_task_lock_params = None
309312
file_path = os.path.join(temp_dir, 'test.csv')
310313
processor = make_file_processor(file_path, store_index_in_feather=False)
311314
file_system_target = luigi.LocalTarget(file_path, format=processor.format())
312-
file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=self.DummyDataFrameSchema)
315+
file_target = SingleFileTarget(
316+
target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=validate_int
317+
)
313318

314319
with self.assertRaises(pa.errors.SchemaError):
315320
file_target.dump(test_case)
316-
317321

318322

319323
if __name__ == '__main__':

0 commit comments

Comments
 (0)