-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
deab952
commit 1b88f07
Showing
4 changed files
with
401 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,359 @@ | ||
import os
import shutil
from unittest.mock import patch, MagicMock

from pyfakefs.fake_filesystem_unittest import TestCase

from starwhale.consts import AUTH_ENV_FNAME, SWDSBackendType
from starwhale.base.uri import URI
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.base.type import URIType, DataFormatType, DataOriginType, ObjectStoreType
from starwhale.api._impl.loader import (
    get_data_loader,
    SWDSBinDataLoader,
    UserRawDataLoader,
)
from starwhale.api._impl.dataset import MIMEType, S3LinkAuth, TabularDatasetRow
from starwhale.core.dataset.store import DatasetStorage
from starwhale.core.dataset.dataset import DatasetSummary

from .. import ROOT_DIR
||
class TestDataLoader(TestCase):
    """Tests for the dataset data loaders (user-raw and swds-bin formats).

    Runs on a fake filesystem (pyfakefs); real fixture files under
    ``tests/data/dataset`` are mapped in explicitly where needed.
    """

    def setUp(self) -> None:
        """Create the fake fs and a standalone dataset URI used by most tests."""
        self.setUpPyfakefs()
        self.dataset_uri = URI("mnist/version/1122334455667788", URIType.DATASET)
        self.swds_dir = os.path.join(ROOT_DIR, "data", "dataset", "swds")
        # Expose the real swds fixture directory inside the fake filesystem.
        self.fs.add_real_directory(self.swds_dir)

    @patch("starwhale.core.dataset.model.StandaloneDataset.summary")
    @patch("starwhale.api._impl.loader.TabularDataset.scan")
    def test_user_raw_local_store(
        self, m_scan: MagicMock, m_summary: MagicMock
    ) -> None:
        """User-raw rows stored locally are read via a FUSE-backed local store."""
        m_summary.return_value = DatasetSummary(
            include_user_raw=True,
            include_link=False,
        )
        loader = get_data_loader(self.dataset_uri)
        assert isinstance(loader, UserRawDataLoader)

        fname = "data"
        m_scan.return_value = [
            TabularDatasetRow(
                id=0,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=16,
                data_size=784,
                label=b"0",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.UNDEFINED,
                data_mime_type=MIMEType.UNDEFINED,
                auth_name="",
            )
        ]

        # Copy the real mnist raw file into the dataset's local data dir so the
        # loader can resolve the row's relative data_uri.
        raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data")
        self.fs.add_real_file(raw_data_fpath)
        data_dir = DatasetStorage(self.dataset_uri).data_dir
        ensure_dir(data_dir)
        shutil.copy(raw_data_fpath, str(data_dir / fname))

        # Stores are created lazily: none exist before iteration.
        assert loader._stores == {}

        rows = list(loader)
        assert len(rows) == 1

        _data, _label = rows[0]

        assert _label.idx == 0
        assert _label.data_size == 1
        assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"}
        assert _data.data_size == len(_data.data)
        # One 28x28 grayscale image was sliced out of the raw file.
        assert len(_data.data) == 28 * 28

        assert loader.kind == DataFormatType.USER_RAW
        # Local rows with no auth name share the single "local." store.
        assert list(loader._stores.keys()) == ["local."]
        assert loader._stores["local."].bucket == str(data_dir)
        assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE
        assert not loader._stores["local."].key_prefix

    @patch.dict(os.environ, {})
    @patch("starwhale.api._impl.loader.boto3.resource")
    @patch("starwhale.core.dataset.model.StandaloneDataset.summary")
    @patch("starwhale.api._impl.loader.TabularDataset.scan")
    def test_user_raw_remote_store(
        self,
        m_scan: MagicMock,
        m_summary: MagicMock,
        m_boto3: MagicMock,
    ) -> None:
        """User-raw rows with remote links restore S3 auth from the dumped env file.

        Verifies that per-auth-name stores are created, including a default
        store for an auth name ("server3") with no dumped credentials.
        """
        m_summary.return_value = DatasetSummary(
            include_user_raw=True,
            include_link=True,
        )

        # Dump two named S3 auth configs to the snapshot workdir's auth env
        # file, then drop them from os.environ to prove the loader reloads
        # them from the file.
        snapshot_workdir = DatasetStorage(self.dataset_uri).snapshot_workdir
        ensure_dir(snapshot_workdir)
        envs = {
            "USER.S3.SERVER1.SECRET": "11",
            "USER.S3.SERVER1.ACCESS_KEY": "11",
            "USER.S3.SERVER2.SECRET": "11",
            "USER.S3.SERVER2.ACCESS_KEY": "11",
            "USER.S3.SERVER2.ENDPOINT": "127.0.0.1:19000",
        }
        os.environ.update(envs)
        auth_env = S3LinkAuth.from_env(name="server1").dump_env()
        auth_env.extend(S3LinkAuth.from_env(name="server2").dump_env())
        ensure_file(
            snapshot_workdir / AUTH_ENV_FNAME,
            content="\n".join(auth_env),
        )

        for k in envs:
            os.environ.pop(k)

        loader = get_data_loader(self.dataset_uri)
        assert isinstance(loader, UserRawDataLoader)
        assert loader.kind == DataFormatType.USER_RAW
        # Creating the loader must have re-populated the env from the auth file.
        for k in envs:
            assert k in os.environ

        version = "1122334455667788"

        # Four remote rows: endpoint-in-uri, second endpoint, endpoint-from-auth,
        # and an inline username:password form with an unknown auth name.
        m_scan.return_value = [
            TabularDatasetRow(
                id=0,
                object_store_type=ObjectStoreType.REMOTE,
                data_uri=f"s3://127.0.0.1:9000@starwhale/project/2/dataset/11/{version}",
                data_offset=16,
                data_size=784,
                label=b"0",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.USER_RAW,
                data_mime_type=MIMEType.GRAYSCALE,
                auth_name="server1",
            ),
            TabularDatasetRow(
                id=1,
                object_store_type=ObjectStoreType.REMOTE,
                data_uri=f"s3://127.0.0.1:19000@starwhale/project/2/dataset/11/{version}",
                data_offset=16,
                data_size=784,
                label=b"1",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.USER_RAW,
                data_mime_type=MIMEType.GRAYSCALE,
                auth_name="server2",
            ),
            TabularDatasetRow(
                id=2,
                object_store_type=ObjectStoreType.REMOTE,
                data_uri=f"s3://starwhale/project/2/dataset/11/{version}",
                data_offset=16,
                data_size=784,
                label=b"1",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.USER_RAW,
                data_mime_type=MIMEType.GRAYSCALE,
                auth_name="server2",
            ),
            TabularDatasetRow(
                id=3,
                object_store_type=ObjectStoreType.REMOTE,
                data_uri=f"s3://username:password@127.0.0.1:29000@starwhale/project/2/dataset/11/{version}",
                data_offset=16,
                data_size=784,
                label=b"1",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.USER_RAW,
                data_mime_type=MIMEType.GRAYSCALE,
                auth_name="server3",
            ),
        ]

        raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data")
        self.fs.add_real_file(raw_data_fpath)
        with open(raw_data_fpath, "rb") as f:
            raw_content = f.read()

        # Fake the boto3 S3 resource so Object(...).get() serves the fixture bytes.
        m_boto3.return_value = MagicMock(
            **{
                "Object.return_value": MagicMock(
                    **{
                        "get.return_value": {
                            "Body": MagicMock(**{"read.return_value": raw_content}),
                            "ContentLength": len(raw_content),
                        }
                    }
                )
            }
        )

        assert loader.kind == DataFormatType.USER_RAW
        assert loader._stores == {}

        rows = list(loader)
        assert len(rows) == 4

        _data, _label = rows[0]
        assert _label.idx == 0
        assert _label.data == b"0"
        assert len(_data.data) == 28 * 28
        assert len(_data.data) == _data.data_size
        # One store per distinct auth name (server1/server2/server3).
        assert len(loader._stores) == 3
        assert loader._stores["remote.server1"].backend.kind == SWDSBackendType.S3
        assert loader._stores["remote.server1"].bucket == "starwhale"

    @patch.dict(os.environ, {})
    @patch("starwhale.api._impl.loader.boto3.resource")
    @patch("starwhale.core.dataset.model.CloudDataset.summary")
    @patch("starwhale.api._impl.loader.TabularDataset.scan")
    def test_swds_bin_s3(
        self, m_scan: MagicMock, m_summary: MagicMock, m_boto3: MagicMock
    ) -> None:
        """A cloud dataset URI yields an S3-backed swds-bin loader driven by SW_S3_* env."""
        m_summary.return_value = DatasetSummary(
            include_user_raw=False,
            include_link=False,
        )
        version = "1122334455667788"
        dataset_uri = URI(
            f"http://127.0.0.1:1234/project/self/dataset/mnist/version/{version}",
            expected_type=URIType.DATASET,
        )
        loader = get_data_loader(dataset_uri)
        assert isinstance(loader, SWDSBinDataLoader)
        assert loader.kind == DataFormatType.SWDS_BIN

        fname = "data_ubyte_0.swds_bin"
        m_scan.return_value = [
            TabularDatasetRow(
                id=0,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=0,
                data_size=8160,
                label=b"0",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.SWDS_BIN,
                data_mime_type=MIMEType.UNDEFINED,
                auth_name="",
            )
        ]
        # SW_* env vars configure the object store the loader builds lazily.
        os.environ.update(
            {
                "SW_S3_BUCKET": "starwhale",
                "SW_OBJECT_STORE_KEY_PREFIX": f"project/self/dataset/mnist/version/11/{version}",
                "SW_S3_ENDPOINT": "starwhale.mock:9000",
                "SW_S3_ACCESS_KEY": "foo",
                "SW_S3_SECRET": "bar",
            }
        )

        with open(os.path.join(self.swds_dir, fname), "rb") as f:
            swds_content = f.read()

        m_boto3.return_value = MagicMock(
            **{
                "Object.return_value": MagicMock(
                    **{
                        "get.return_value": {
                            "Body": MagicMock(**{"read.return_value": swds_content}),
                            "ContentLength": len(swds_content),
                        }
                    }
                )
            }
        )
        assert loader._stores == {}

        rows = list(loader)
        assert len(rows) == 1
        _data, _label = rows[0]
        assert _label.idx == 0
        assert _label.data == b"0"
        assert _label.data_size == 1

        assert len(_data.data) == _data.data_size
        # A swds-bin block holds a batch of 10 28x28 images.
        assert _data.data_size == 10 * 28 * 28
        assert _data.ext_attr == {"ds_name": "mnist", "ds_version": version}

        assert list(loader._stores.keys()) == ["local."]
        backend = loader._stores["local."].backend
        assert backend.kind == SWDSBackendType.S3
        # The S3 object key is key_prefix + row's relative data_uri.
        assert backend.s3.Object.call_args[0] == (
            "starwhale",
            f"project/self/dataset/mnist/version/11/{version}/{fname}",
        )

        assert loader._stores["local."].bucket == "starwhale"
        assert (
            loader._stores["local."].key_prefix
            == f"project/self/dataset/mnist/version/11/{version}"
        )

    @patch.dict(os.environ, {})
    @patch("starwhale.core.dataset.model.StandaloneDataset.summary")
    @patch("starwhale.api._impl.loader.TabularDataset.scan")
    def test_swds_bin_fuse(self, m_scan: MagicMock, m_summary: MagicMock) -> None:
        """A standalone dataset URI yields a FUSE-backed swds-bin loader."""
        m_summary.return_value = DatasetSummary(
            include_user_raw=False,
            include_link=False,
            rows=2,
            increased_rows=2,
        )
        loader = get_data_loader(self.dataset_uri)
        assert isinstance(loader, SWDSBinDataLoader)
        assert loader.kind == DataFormatType.SWDS_BIN

        # Two rows pointing at the same local swds-bin file.
        fname = "data_ubyte_0.swds_bin"
        m_scan.return_value = [
            TabularDatasetRow(
                id=0,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=0,
                data_size=8160,
                label=b"0",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.SWDS_BIN,
                data_mime_type=MIMEType.UNDEFINED,
                auth_name="",
            ),
            TabularDatasetRow(
                id=1,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=0,
                data_size=8160,
                label=b"1",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.SWDS_BIN,
                data_mime_type=MIMEType.UNDEFINED,
                auth_name="",
            ),
        ]

        data_dir = DatasetStorage(self.dataset_uri).data_dir
        ensure_dir(data_dir)
        shutil.copyfile(os.path.join(self.swds_dir, fname), str(data_dir / fname))
        assert loader._stores == {}

        rows = list(loader)
        assert len(rows) == 2

        _data, _label = rows[0]
        assert _label.idx == 0
        assert _label.data == b"0"
        assert _label.data_size == 1

        assert len(_data.data) == _data.data_size
        # A swds-bin block holds a batch of 10 28x28 images.
        assert _data.data_size == 10 * 28 * 28
        assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"}

        assert list(loader._stores.keys()) == ["local."]
        assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE
        assert loader._stores["local."].bucket == str(data_dir)
        assert not loader._stores["local."].key_prefix
Oops, something went wrong.