add test for dataset loader
tianweidut committed Aug 22, 2022
1 parent deab952 commit 1b88f07
Showing 4 changed files with 401 additions and 9 deletions.
9 changes: 2 additions & 7 deletions client/starwhale/api/_impl/dataset.py
@@ -174,13 +174,8 @@ def from_env(cls, name: str = "") -> S3LinkAuth:
_secret_name = cls._SECRET_FMT.format(name=_name)
_access_name = cls._ACCESS_KEY_FMT.format(name=_name)

_secret = _env(_secret_name)
_access = _env(_access_name)
if not _secret or not _access:
raise FieldTypeOrValueError(
f"cannot find secret[{_secret_name}] or access[{_access_name}] key env"
)

_secret = _env(_secret_name, "")
_access = _env(_access_name, "")
return cls(
name,
_access,
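
The change above removes the hard failure when the link-auth environment variables are missing: from_env now falls back to empty strings instead of raising FieldTypeOrValueError. A minimal sketch of the new behavior, assuming a hypothetical "server1" auth name and an environment where neither credential variable is exported:

import os

from starwhale.api._impl.dataset import S3LinkAuth

# Assumption for illustration: no credentials are exported for "server1".
os.environ.pop("USER.S3.SERVER1.SECRET", None)
os.environ.pop("USER.S3.SERVER1.ACCESS_KEY", None)

# Before this commit the call below raised FieldTypeOrValueError; with this
# change it returns an S3LinkAuth built from empty-string credentials.
auth = S3LinkAuth.from_env(name="server1")
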
7 changes: 6 additions & 1 deletion client/starwhale/api/_impl/loader.py
@@ -153,7 +153,6 @@ def __init__(
**kw: t.Any,
) -> None:
self.bucket = bucket
self.conn: t.Optional[ObjectStoreS3Connection]

self.backend: StorageBackend
if backend == SWDSBackendType.S3:
@@ -164,6 +163,12 @@ def __init__(

self.key_prefix = key_prefix or os.environ.get("SW_OBJECT_STORE_KEY_PREFIX", "")

def __str__(self) -> str:
return f"DatasetObjectStore backend:{self.backend}"

def __repr__(self) -> str:
return f"DatasetObjectStore backend:{self.backend}, bucket:{self.bucket}, key_prefix:{self.key_prefix}"

@classmethod
def from_data_link_uri(cls, data_uri: str, auth_name: str) -> DatasetObjectStore:
data_uri = data_uri.strip()
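
The new __str__ and __repr__ methods make object-store instances easier to identify in logs and test failures. A rough usage sketch, using illustrative values borrowed from the tests below (building a real store this way would also require the matching auth environment variables to be set):

from starwhale.api._impl.loader import DatasetObjectStore

# Illustrative only: build a store from a remote data link, then print it.
store = DatasetObjectStore.from_data_link_uri(
    "s3://127.0.0.1:9000@starwhale/project/2/dataset/11/1122334455667788",
    auth_name="server1",
)
print(str(store))   # DatasetObjectStore backend:<backend>
print(repr(store))  # same, plus bucket and key_prefix
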
359 changes: 359 additions & 0 deletions client/tests/sdk/test_loader.py
@@ -0,0 +1,359 @@
import os
import shutil
from unittest.mock import patch, MagicMock

from pyfakefs.fake_filesystem_unittest import TestCase

from starwhale.consts import AUTH_ENV_FNAME, SWDSBackendType
from starwhale.base.uri import URI
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.base.type import URIType, DataFormatType, DataOriginType, ObjectStoreType
from starwhale.api._impl.loader import (
get_data_loader,
SWDSBinDataLoader,
UserRawDataLoader,
)
from starwhale.api._impl.dataset import MIMEType, S3LinkAuth, TabularDatasetRow
from starwhale.core.dataset.store import DatasetStorage
from starwhale.core.dataset.dataset import DatasetSummary

from .. import ROOT_DIR


class TestDataLoader(TestCase):
def setUp(self) -> None:
self.setUpPyfakefs()
self.dataset_uri = URI("mnist/version/1122334455667788", URIType.DATASET)
self.swds_dir = os.path.join(ROOT_DIR, "data", "dataset", "swds")
self.fs.add_real_directory(self.swds_dir)

@patch("starwhale.core.dataset.model.StandaloneDataset.summary")
@patch("starwhale.api._impl.loader.TabularDataset.scan")
def test_user_raw_local_store(
self, m_scan: MagicMock, m_summary: MagicMock
) -> None:
m_summary.return_value = DatasetSummary(
include_user_raw=True,
include_link=False,
)
loader = get_data_loader(self.dataset_uri)
assert isinstance(loader, UserRawDataLoader)

fname = "data"
m_scan.return_value = [
TabularDatasetRow(
id=0,
object_store_type=ObjectStoreType.LOCAL,
data_uri=fname,
data_offset=16,
data_size=784,
label=b"0",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.UNDEFINED,
data_mime_type=MIMEType.UNDEFINED,
auth_name="",
)
]

raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data")
self.fs.add_real_file(raw_data_fpath)
data_dir = DatasetStorage(self.dataset_uri).data_dir
ensure_dir(data_dir)
shutil.copy(raw_data_fpath, str(data_dir / fname))

assert loader._stores == {}

rows = list(loader)
assert len(rows) == 1

_data, _label = rows[0]

assert _label.idx == 0
assert _label.data_size == 1
assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"}
assert _data.data_size == len(_data.data)
assert len(_data.data) == 28 * 28

assert loader.kind == DataFormatType.USER_RAW
assert list(loader._stores.keys()) == ["local."]
assert loader._stores["local."].bucket == str(data_dir)
assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE
assert not loader._stores["local."].key_prefix

@patch.dict(os.environ, {})
@patch("starwhale.api._impl.loader.boto3.resource")
@patch("starwhale.core.dataset.model.StandaloneDataset.summary")
@patch("starwhale.api._impl.loader.TabularDataset.scan")
def test_user_raw_remote_store(
self,
m_scan: MagicMock,
m_summary: MagicMock,
m_boto3: MagicMock,
) -> None:
m_summary.return_value = DatasetSummary(
include_user_raw=True,
include_link=True,
)

snapshot_workdir = DatasetStorage(self.dataset_uri).snapshot_workdir
ensure_dir(snapshot_workdir)
envs = {
"USER.S3.SERVER1.SECRET": "11",
"USER.S3.SERVER1.ACCESS_KEY": "11",
"USER.S3.SERVER2.SECRET": "11",
"USER.S3.SERVER2.ACCESS_KEY": "11",
"USER.S3.SERVER2.ENDPOINT": "127.0.0.1:19000",
}
os.environ.update(envs)
auth_env = S3LinkAuth.from_env(name="server1").dump_env()
auth_env.extend(S3LinkAuth.from_env(name="server2").dump_env())
ensure_file(
snapshot_workdir / AUTH_ENV_FNAME,
content="\n".join(auth_env),
)

for k in envs:
os.environ.pop(k)

loader = get_data_loader(self.dataset_uri)
assert isinstance(loader, UserRawDataLoader)
assert loader.kind == DataFormatType.USER_RAW
for k in envs:
assert k in os.environ

version = "1122334455667788"

m_scan.return_value = [
TabularDatasetRow(
id=0,
object_store_type=ObjectStoreType.REMOTE,
data_uri=f"s3://127.0.0.1:9000@starwhale/project/2/dataset/11/{version}",
data_offset=16,
data_size=784,
label=b"0",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.USER_RAW,
data_mime_type=MIMEType.GRAYSCALE,
auth_name="server1",
),
TabularDatasetRow(
id=1,
object_store_type=ObjectStoreType.REMOTE,
data_uri=f"s3://127.0.0.1:19000@starwhale/project/2/dataset/11/{version}",
data_offset=16,
data_size=784,
label=b"1",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.USER_RAW,
data_mime_type=MIMEType.GRAYSCALE,
auth_name="server2",
),
TabularDatasetRow(
id=2,
object_store_type=ObjectStoreType.REMOTE,
data_uri=f"s3://starwhale/project/2/dataset/11/{version}",
data_offset=16,
data_size=784,
label=b"1",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.USER_RAW,
data_mime_type=MIMEType.GRAYSCALE,
auth_name="server2",
),
TabularDatasetRow(
id=3,
object_store_type=ObjectStoreType.REMOTE,
data_uri=f"s3://username:password@127.0.0.1:29000@starwhale/project/2/dataset/11/{version}",
data_offset=16,
data_size=784,
label=b"1",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.USER_RAW,
data_mime_type=MIMEType.GRAYSCALE,
auth_name="server3",
),
]

raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data")
self.fs.add_real_file(raw_data_fpath)
with open(raw_data_fpath, "rb") as f:
raw_content = f.read(-1)

m_boto3.return_value = MagicMock(
**{
"Object.return_value": MagicMock(
**{
"get.return_value": {
"Body": MagicMock(**{"read.return_value": raw_content}),
"ContentLength": len(raw_content),
}
}
)
}
)

assert loader.kind == DataFormatType.USER_RAW
assert loader._stores == {}

rows = list(loader)
assert len(rows) == 4

_data, _label = rows[0]
assert _label.idx == 0
assert _label.data == b"0"
assert len(_data.data) == 28 * 28
assert len(_data.data) == _data.data_size
assert len(loader._stores) == 3
assert loader._stores["remote.server1"].backend.kind == SWDSBackendType.S3
assert loader._stores["remote.server1"].bucket == "starwhale"

@patch.dict(os.environ, {})
@patch("starwhale.api._impl.loader.boto3.resource")
@patch("starwhale.core.dataset.model.CloudDataset.summary")
@patch("starwhale.api._impl.loader.TabularDataset.scan")
def test_swds_bin_s3(
self, m_scan: MagicMock, m_summary: MagicMock, m_boto3: MagicMock
) -> None:
m_summary.return_value = DatasetSummary(
include_user_raw=False,
include_link=False,
)
version = "1122334455667788"
dataset_uri = URI(
f"http://127.0.0.1:1234/project/self/dataset/mnist/version/{version}",
expected_type=URIType.DATASET,
)
loader = get_data_loader(dataset_uri)
assert isinstance(loader, SWDSBinDataLoader)
assert loader.kind == DataFormatType.SWDS_BIN

fname = "data_ubyte_0.swds_bin"
m_scan.return_value = [
TabularDatasetRow(
id=0,
object_store_type=ObjectStoreType.LOCAL,
data_uri=fname,
data_offset=0,
data_size=8160,
label=b"0",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.SWDS_BIN,
data_mime_type=MIMEType.UNDEFINED,
auth_name="",
)
]
os.environ.update(
{
"SW_S3_BUCKET": "starwhale",
"SW_OBJECT_STORE_KEY_PREFIX": f"project/self/dataset/mnist/version/11/{version}",
"SW_S3_ENDPOINT": "starwhale.mock:9000",
"SW_S3_ACCESS_KEY": "foo",
"SW_S3_SECRET": "bar",
}
)

with open(os.path.join(self.swds_dir, fname), "rb") as f:
swds_content = f.read(-1)

m_boto3.return_value = MagicMock(
**{
"Object.return_value": MagicMock(
**{
"get.return_value": {
"Body": MagicMock(**{"read.return_value": swds_content}),
"ContentLength": len(swds_content),
}
}
)
}
)
assert loader._stores == {}

rows = list(loader)
assert len(rows) == 1
_data, _label = rows[0]
assert _label.idx == 0
assert _label.data == b"0"
assert _label.data_size == 1

assert len(_data.data) == _data.data_size
assert _data.data_size == 10 * 28 * 28
assert _data.ext_attr == {"ds_name": "mnist", "ds_version": version}

assert list(loader._stores.keys()) == ["local."]
backend = loader._stores["local."].backend
assert backend.kind == SWDSBackendType.S3
assert backend.s3.Object.call_args[0] == (
"starwhale",
f"project/self/dataset/mnist/version/11/{version}/{fname}",
)

assert loader._stores["local."].bucket == "starwhale"
assert (
loader._stores["local."].key_prefix
== f"project/self/dataset/mnist/version/11/{version}"
)

@patch.dict(os.environ, {})
@patch("starwhale.core.dataset.model.StandaloneDataset.summary")
@patch("starwhale.api._impl.loader.TabularDataset.scan")
def test_swds_bin_fuse(self, m_scan: MagicMock, m_summary: MagicMock) -> None:
m_summary.return_value = DatasetSummary(
include_user_raw=False,
include_link=False,
rows=2,
increased_rows=2,
)
loader = get_data_loader(self.dataset_uri)
assert isinstance(loader, SWDSBinDataLoader)
assert loader.kind == DataFormatType.SWDS_BIN

fname = "data_ubyte_0.swds_bin"
m_scan.return_value = [
TabularDatasetRow(
id=0,
object_store_type=ObjectStoreType.LOCAL,
data_uri=fname,
data_offset=0,
data_size=8160,
label=b"0",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.SWDS_BIN,
data_mime_type=MIMEType.UNDEFINED,
auth_name="",
),
TabularDatasetRow(
id=1,
object_store_type=ObjectStoreType.LOCAL,
data_uri=fname,
data_offset=0,
data_size=8160,
label=b"1",
data_origin=DataOriginType.NEW,
data_format=DataFormatType.SWDS_BIN,
data_mime_type=MIMEType.UNDEFINED,
auth_name="",
),
]

data_dir = DatasetStorage(self.dataset_uri).data_dir
ensure_dir(data_dir)
shutil.copyfile(os.path.join(self.swds_dir, fname), str(data_dir / fname))
assert loader._stores == {}

rows = list(loader)
assert len(rows) == 2

_data, _label = rows[0]
assert _label.idx == 0
assert _label.data == b"0"
assert _label.data_size == 1

assert len(_data.data) == _data.data_size
assert _data.data_size == 10 * 28 * 28
assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"}

assert list(loader._stores.keys()) == ["local."]
assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE
assert loader._stores["local."].bucket == str(data_dir)
assert not loader._stores["local."].key_prefix
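
For reference, the boto3 stub used in test_user_raw_remote_store and test_swds_bin_s3 above works by making Object(...).get() return a dict that carries the raw bytes. A self-contained sketch of that pattern, with a placeholder payload, bucket, and key instead of the fixture data:

from unittest.mock import MagicMock

# Stand-in payload; the real tests read it from a fixture file.
content = b"\x00" * 16

# Object(...) returns a mock whose get() yields the body and its length,
# mirroring the shape of boto3's S3 Object.get() response.
resource = MagicMock(
    **{
        "Object.return_value": MagicMock(
            **{
                "get.return_value": {
                    "Body": MagicMock(**{"read.return_value": content}),
                    "ContentLength": len(content),
                }
            }
        )
    }
)

obj = resource.Object("starwhale", "placeholder/key")
assert obj.get()["Body"].read() == content
assert obj.get()["ContentLength"] == len(content)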