From 1b88f07ba23f2c7d90e09c9131ea0e235784e93e Mon Sep 17 00:00:00 2001 From: tianweidut Date: Mon, 22 Aug 2022 11:05:14 +0800 Subject: [PATCH] add test for dataset loader --- client/starwhale/api/_impl/dataset.py | 9 +- client/starwhale/api/_impl/loader.py | 7 +- client/tests/sdk/test_loader.py | 359 ++++++++++++++++++++++++++ client/tests/utils/test_common.py | 35 ++- 4 files changed, 401 insertions(+), 9 deletions(-) create mode 100644 client/tests/sdk/test_loader.py diff --git a/client/starwhale/api/_impl/dataset.py b/client/starwhale/api/_impl/dataset.py index 18896b0a7d..324a59ca71 100644 --- a/client/starwhale/api/_impl/dataset.py +++ b/client/starwhale/api/_impl/dataset.py @@ -174,13 +174,8 @@ def from_env(cls, name: str = "") -> S3LinkAuth: _secret_name = cls._SECRET_FMT.format(name=_name) _access_name = cls._ACCESS_KEY_FMT.format(name=_name) - _secret = _env(_secret_name) - _access = _env(_access_name) - if not _secret or not _access: - raise FieldTypeOrValueError( - f"cannot find secret[{_secret_name}] or access[{_access_name}] key env" - ) - + _secret = _env(_secret_name, "") + _access = _env(_access_name, "") return cls( name, _access, diff --git a/client/starwhale/api/_impl/loader.py b/client/starwhale/api/_impl/loader.py index 6f508357c2..bbbdd114b6 100644 --- a/client/starwhale/api/_impl/loader.py +++ b/client/starwhale/api/_impl/loader.py @@ -153,7 +153,6 @@ def __init__( **kw: t.Any, ) -> None: self.bucket = bucket - self.conn: t.Optional[ObjectStoreS3Connection] self.backend: StorageBackend if backend == SWDSBackendType.S3: @@ -164,6 +163,12 @@ def __init__( self.key_prefix = key_prefix or os.environ.get("SW_OBJECT_STORE_KEY_PREFIX", "") + def __str__(self) -> str: + return f"DatasetObjectStore backend:{self.backend}" + + def __repr__(self) -> str: + return f"DatasetObjectStore backend:{self.backend}, bucket:{self.bucket}, key_prefix:{self.key_prefix}" + @classmethod def from_data_link_uri(cls, data_uri: str, auth_name: str) -> DatasetObjectStore: data_uri = data_uri.strip() diff --git a/client/tests/sdk/test_loader.py b/client/tests/sdk/test_loader.py new file mode 100644 index 0000000000..87da1dcfec --- /dev/null +++ b/client/tests/sdk/test_loader.py @@ -0,0 +1,359 @@ +import os +import shutil +from unittest.mock import patch, MagicMock + +from pyfakefs.fake_filesystem_unittest import TestCase + +from starwhale.consts import AUTH_ENV_FNAME, SWDSBackendType +from starwhale.base.uri import URI +from starwhale.utils.fs import ensure_dir, ensure_file +from starwhale.base.type import URIType, DataFormatType, DataOriginType, ObjectStoreType +from starwhale.api._impl.loader import ( + get_data_loader, + SWDSBinDataLoader, + UserRawDataLoader, +) +from starwhale.api._impl.dataset import MIMEType, S3LinkAuth, TabularDatasetRow +from starwhale.core.dataset.store import DatasetStorage +from starwhale.core.dataset.dataset import DatasetSummary + +from .. 
import ROOT_DIR + + +class TestDataLoader(TestCase): + def setUp(self) -> None: + self.setUpPyfakefs() + self.dataset_uri = URI("mnist/version/1122334455667788", URIType.DATASET) + self.swds_dir = os.path.join(ROOT_DIR, "data", "dataset", "swds") + self.fs.add_real_directory(self.swds_dir) + + @patch("starwhale.core.dataset.model.StandaloneDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_user_raw_local_store( + self, m_scan: MagicMock, m_summary: MagicMock + ) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=True, + include_link=False, + ) + loader = get_data_loader(self.dataset_uri) + assert isinstance(loader, UserRawDataLoader) + + fname = "data" + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=16, + data_size=784, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.UNDEFINED, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ) + ] + + raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data") + self.fs.add_real_file(raw_data_fpath) + data_dir = DatasetStorage(self.dataset_uri).data_dir + ensure_dir(data_dir) + shutil.copy(raw_data_fpath, str(data_dir / fname)) + + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 1 + + _data, _label = rows[0] + + assert _label.idx == 0 + assert _label.data_size == 1 + assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"} + assert _data.data_size == len(_data.data) + assert len(_data.data) == 28 * 28 + + assert loader.kind == DataFormatType.USER_RAW + assert list(loader._stores.keys()) == ["local."] + assert loader._stores["local."].bucket == str(data_dir) + assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE + assert not loader._stores["local."].key_prefix + + @patch.dict(os.environ, {}) + @patch("starwhale.api._impl.loader.boto3.resource") + @patch("starwhale.core.dataset.model.StandaloneDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_user_raw_remote_store( + self, + m_scan: MagicMock, + m_summary: MagicMock, + m_boto3: MagicMock, + ) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=True, + include_link=True, + ) + + snapshot_workdir = DatasetStorage(self.dataset_uri).snapshot_workdir + ensure_dir(snapshot_workdir) + envs = { + "USER.S3.SERVER1.SECRET": "11", + "USER.S3.SERVER1.ACCESS_KEY": "11", + "USER.S3.SERVER2.SECRET": "11", + "USER.S3.SERVER2.ACCESS_KEY": "11", + "USER.S3.SERVER2.ENDPOINT": "127.0.0.1:19000", + } + os.environ.update(envs) + auth_env = S3LinkAuth.from_env(name="server1").dump_env() + auth_env.extend(S3LinkAuth.from_env(name="server2").dump_env()) + ensure_file( + snapshot_workdir / AUTH_ENV_FNAME, + content="\n".join(auth_env), + ) + + for k in envs: + os.environ.pop(k) + + loader = get_data_loader(self.dataset_uri) + assert isinstance(loader, UserRawDataLoader) + assert loader.kind == DataFormatType.USER_RAW + for k in envs: + assert k in os.environ + + version = "1122334455667788" + + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://127.0.0.1:9000@starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server1", + ), + TabularDatasetRow( + id=1, + object_store_type=ObjectStoreType.REMOTE, + 
data_uri=f"s3://127.0.0.1:19000@starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server2", + ), + TabularDatasetRow( + id=2, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server2", + ), + TabularDatasetRow( + id=3, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://username:password@127.0.0.1:29000@starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server3", + ), + ] + + raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data") + self.fs.add_real_file(raw_data_fpath) + with open(raw_data_fpath, "rb") as f: + raw_content = f.read(-1) + + m_boto3.return_value = MagicMock( + **{ + "Object.return_value": MagicMock( + **{ + "get.return_value": { + "Body": MagicMock(**{"read.return_value": raw_content}), + "ContentLength": len(raw_content), + } + } + ) + } + ) + + assert loader.kind == DataFormatType.USER_RAW + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 4 + + _data, _label = rows[0] + assert _label.idx == 0 + assert _label.data == b"0" + assert len(_data.data) == 28 * 28 + assert len(_data.data) == _data.data_size + assert len(loader._stores) == 3 + assert loader._stores["remote.server1"].backend.kind == SWDSBackendType.S3 + assert loader._stores["remote.server1"].bucket == "starwhale" + + @patch.dict(os.environ, {}) + @patch("starwhale.api._impl.loader.boto3.resource") + @patch("starwhale.core.dataset.model.CloudDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_swds_bin_s3( + self, m_scan: MagicMock, m_summary: MagicMock, m_boto3: MagicMock + ) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=False, + include_link=False, + ) + version = "1122334455667788" + dataset_uri = URI( + f"http://127.0.0.1:1234/project/self/dataset/mnist/version/{version}", + expected_type=URIType.DATASET, + ) + loader = get_data_loader(dataset_uri) + assert isinstance(loader, SWDSBinDataLoader) + assert loader.kind == DataFormatType.SWDS_BIN + + fname = "data_ubyte_0.swds_bin" + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=0, + data_size=8160, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.SWDS_BIN, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ) + ] + os.environ.update( + { + "SW_S3_BUCKET": "starwhale", + "SW_OBJECT_STORE_KEY_PREFIX": f"project/self/dataset/mnist/version/11/{version}", + "SW_S3_ENDPOINT": "starwhale.mock:9000", + "SW_S3_ACCESS_KEY": "foo", + "SW_S3_SECRET": "bar", + } + ) + + with open(os.path.join(self.swds_dir, fname), "rb") as f: + swds_content = f.read(-1) + + m_boto3.return_value = MagicMock( + **{ + "Object.return_value": MagicMock( + **{ + "get.return_value": { + "Body": MagicMock(**{"read.return_value": swds_content}), + "ContentLength": len(swds_content), + } + } + ) + } + ) + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 1 + _data, _label = rows[0] + assert _label.idx == 0 + 
assert _label.data == b"0" + assert _label.data_size == 1 + + assert len(_data.data) == _data.data_size + assert _data.data_size == 10 * 28 * 28 + assert _data.ext_attr == {"ds_name": "mnist", "ds_version": version} + + assert list(loader._stores.keys()) == ["local."] + backend = loader._stores["local."].backend + assert backend.kind == SWDSBackendType.S3 + assert backend.s3.Object.call_args[0] == ( + "starwhale", + f"project/self/dataset/mnist/version/11/{version}/{fname}", + ) + + assert loader._stores["local."].bucket == "starwhale" + assert ( + loader._stores["local."].key_prefix + == f"project/self/dataset/mnist/version/11/{version}" + ) + + @patch.dict(os.environ, {}) + @patch("starwhale.core.dataset.model.StandaloneDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_swds_bin_fuse(self, m_scan: MagicMock, m_summary: MagicMock) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=False, + include_link=False, + rows=2, + increased_rows=2, + ) + loader = get_data_loader(self.dataset_uri) + assert isinstance(loader, SWDSBinDataLoader) + assert loader.kind == DataFormatType.SWDS_BIN + + fname = "data_ubyte_0.swds_bin" + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=0, + data_size=8160, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.SWDS_BIN, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ), + TabularDatasetRow( + id=1, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=0, + data_size=8160, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.SWDS_BIN, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ), + ] + + data_dir = DatasetStorage(self.dataset_uri).data_dir + ensure_dir(data_dir) + shutil.copyfile(os.path.join(self.swds_dir, fname), str(data_dir / fname)) + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 2 + + _data, _label = rows[0] + assert _label.idx == 0 + assert _label.data == b"0" + assert _label.data_size == 1 + + assert len(_data.data) == _data.data_size + assert _data.data_size == 10 * 28 * 28 + assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"} + + assert list(loader._stores.keys()) == ["local."] + assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE + assert loader._stores["local."].bucket == str(data_dir) + assert not loader._stores["local."].key_prefix diff --git a/client/tests/utils/test_common.py b/client/tests/utils/test_common.py index ab3dc934aa..b5e1955891 100644 --- a/client/tests/utils/test_common.py +++ b/client/tests/utils/test_common.py @@ -1,6 +1,13 @@ import os +import typing as t +from pathlib import Path +from unittest.mock import patch -from starwhale.utils import validate_obj_name +import pytest +from pyfakefs.fake_filesystem import FakeFilesystem +from pyfakefs.fake_filesystem_unittest import Patcher + +from starwhale.utils import load_dotenv, validate_obj_name from starwhale.consts import ENV_LOG_LEVEL from starwhale.utils.debug import init_logger @@ -24,3 +31,29 @@ def test_logger() -> None: init_logger(3) assert os.environ[ENV_LOG_LEVEL] == "DEBUG" + + +@pytest.fixture +def fake_fs() -> t.Generator[t.Optional[FakeFilesystem], None, None]: + with Patcher() as patcher: + yield patcher.fs + + +@patch.dict(os.environ, {"TEST_ENV": "1"}, clear=True) +def test_load_dotenv(fake_fs: FakeFilesystem) -> None: + content = """ + # this is a comment line + A=1 + B = 2 + 
c = + ddd + """ + fpath = "/home/starwhale/test/.auth_env" + fake_fs.create_file(fpath, contents=content) + assert os.environ["TEST_ENV"] == "1" + load_dotenv(Path(fpath)) + assert os.environ["A"] == "1" + assert os.environ["B"] == "2" + assert not os.environ["c"] + assert "ddd" not in os.environ + assert len(os.environ) == 4
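
The new test_load_dotenv case pins down the parsing rules that .auth_env files are expected to follow: comment lines and lines without "=" are ignored, keys and values are stripped of surrounding whitespace, and an empty right-hand side is stored as an empty string. Below is a minimal sketch of a parser with that behavior, assuming starwhale.utils.load_dotenv follows these rules; the real implementation may differ.

import os
from pathlib import Path


def load_dotenv_sketch(fpath: Path) -> None:
    # Hypothetical helper for illustration only; it mirrors the behavior
    # asserted in test_load_dotenv, not the actual starwhale.utils.load_dotenv.
    if not fpath.exists():
        return
    for line in fpath.read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue  # skip blank lines, "# ..." comments, and bare tokens like "ddd"
        key, _, value = line.partition("=")
        os.environ[key.strip()] = value.strip()  # "c =" ends up as os.environ["c"] == ""

Together with the dataset.py hunk above, S3LinkAuth.from_env() now returns empty credentials instead of raising FieldTypeOrValueError when the corresponding SECRET/ACCESS_KEY environment variables are unset.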