feat: replace link auth file to link auth config (#1682)
anda-ren authored Jan 6, 2023
1 parent e55db8d commit 039020b
Showing 20 changed files with 458 additions and 333 deletions.
27 changes: 1 addition & 26 deletions client/starwhale/api/_impl/dataset/builder.py
@@ -16,12 +16,7 @@
import jsonlines
from loguru import logger

from starwhale.consts import (
AUTH_ENV_FNAME,
DEFAULT_PROJECT,
STANDALONE_INSTANCE,
SWDS_DATA_FNAME_FMT,
)
from starwhale.consts import DEFAULT_PROJECT, STANDALONE_INSTANCE, SWDS_DATA_FNAME_FMT
from starwhale.base.uri import URI
from starwhale.utils.fs import empty_dir, ensure_dir
from starwhale.base.type import (
@@ -35,7 +30,6 @@
from starwhale.core.dataset.type import (
Link,
Binary,
LinkAuth,
MIMEType,
BaseArtifact,
DatasetSummary,
@@ -381,7 +375,6 @@ class UserRawBuildExecutor(BaseBuildExecutor):
def make_swds(self) -> DatasetSummary:
increased_rows = 0
total_data_size = 0
auth_candidates: t.Dict[str, LinkAuth] = {}
include_link = False

map_path_sign: t.Dict[str, t.Tuple[str, Path]] = {}
@@ -410,7 +403,6 @@ def make_swds(self) -> DatasetSummary:
_data_fpath
)
data_uri, _ = map_path_sign[_data_fpath]
auth = ""
object_store_type = ObjectStoreType.LOCAL

def _travel_link(obj: t.Any) -> None:
@@ -438,13 +430,6 @@ def _travel_link(obj: t.Any) -> None:
else:
_remote_link = row_data
data_uri = _remote_link.uri
if _remote_link.auth:
auth = _remote_link.auth.name
auth_candidates[
f"{_remote_link.auth.ltype}.{_remote_link.auth.name}"
] = _remote_link.auth
else:
auth = ""
object_store_type = ObjectStoreType.REMOTE
include_link = True

@@ -457,7 +442,6 @@ def _travel_link(obj: t.Any) -> None:
data_offset=row_data.offset,
data_size=row_data.size,
data_origin=DataOriginType.NEW,
auth_name=auth,
data_type=row_data.astype(),
annotations=row_annotations,
_append_seq_id=append_seq_id,
@@ -468,7 +452,6 @@ def _travel_link(obj: t.Any) -> None:
increased_rows += 1

self._copy_files(map_path_sign)
self._copy_auth(auth_candidates)
self.tabular_dataset.info = self.get_info() # type: ignore

# TODO: provide fine-grained rows/increased rows by dataset pythonic api
@@ -489,14 +472,6 @@ def _copy_files(self, map_path_sign: t.Dict[str, t.Tuple[str, Path]]) -> None:
obj_path.absolute()
)

def _copy_auth(self, auth_candidates: t.Dict[str, LinkAuth]) -> None:
if not auth_candidates:
return

with (self.workdir / AUTH_ENV_FNAME).open("w") as f:
for auth in auth_candidates.values():
f.write("\n".join(auth.dump_env()))

@property
def data_format_type(self) -> DataFormatType:
return DataFormatType.USER_RAW
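The builder no longer serializes per-link credentials: `auth_candidates`, the per-row `auth_name`, and the `_copy_auth` step that wrote a `.auth_env` file into the dataset workdir are all removed. For context, a generic sketch of the round trip this drops (not the starwhale API; the actual variable names produced by `LinkAuth.dump_env` are not shown in this diff):

import os
from pathlib import Path

# Build side (old behaviour): dump KEY=VALUE credential lines into <workdir>/.auth_env.
workdir = Path("/tmp/swds-auth-demo")
workdir.mkdir(parents=True, exist_ok=True)
auth_env = workdir / ".auth_env"
auth_env.write_text(
    "SW_DEMO_ENDPOINT=10.0.0.1:9000\n"   # hypothetical variable names
    "SW_DEMO_ACCESS_KEY=AK\n"
    "SW_DEMO_SECRET=SK\n"
)

# Load side (old behaviour): re-export the file into the environment, roughly what
# load_dotenv(auth_env_fpath) did before opening remote links.
for line in auth_env.read_text().splitlines():
    key, _, value = line.partition("=")
    os.environ.setdefault(key, value)
print(os.environ["SW_DEMO_ENDPOINT"])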
26 changes: 9 additions & 17 deletions client/starwhale/api/_impl/dataset/loader.py
@@ -6,18 +6,18 @@
import threading
from abc import ABCMeta, abstractmethod
from functools import total_ordering
from urllib.parse import urlparse

import loguru
from loguru import logger as _logger

from starwhale.utils import load_dotenv
from starwhale.consts import HTTPMethod, AUTH_ENV_FNAME
from starwhale.consts import HTTPMethod
from starwhale.base.uri import URI
from starwhale.base.type import URIType, InstanceType, DataFormatType, ObjectStoreType
from starwhale.base.cloud import CloudRequestMixed
from starwhale.utils.error import ParameterError
from starwhale.core.dataset.type import Link, BaseArtifact
from starwhale.core.dataset.store import FileLikeObj, ObjectStore, DatasetStorage
from starwhale.core.dataset.store import FileLikeObj, ObjectStore
from starwhale.api._impl.data_store import SwObject
from starwhale.core.dataset.tabular import (
TabularDataset,
@@ -106,7 +106,6 @@ def __init__(
)
self.session_consumption = session_consumption
self._stores: t.Dict[str, ObjectStore] = {}
self._load_dataset_auth_env()
self.last_processed_range: t.Optional[t.Tuple[t.Any, t.Any]] = None
self._store_lock = threading.Lock()

@@ -120,17 +119,12 @@ def __init__(
raise ValueError(f"cache_size({cache_size}) must be a positive int number")
self._cache_size = cache_size

def _load_dataset_auth_env(self) -> None:
# TODO: support multi datasets
if self.dataset_uri.instance_type == InstanceType.STANDALONE:
auth_env_fpath = (
DatasetStorage(self.dataset_uri).snapshot_workdir / AUTH_ENV_FNAME
)
load_dotenv(auth_env_fpath)

def _get_store(self, row: TabularDatasetRow) -> ObjectStore:
with self._store_lock:
_k = f"{self.dataset_uri}.{row.data_link.scheme}.{row.auth_name}"
_up = urlparse(row.data_link.uri)
_parts = _up.path.lstrip("/").split("/", 1)
_cache_key = row.data_link.uri.replace(_parts[-1], "")
_k = f"{self.dataset_uri}.{_cache_key}"
_store = self._stores.get(_k)
if _store:
return _store
@@ -139,9 +133,7 @@ def _get_store(self, row: TabularDatasetRow) -> ObjectStore:
_store = ObjectStore.to_signed_http_backend(self.dataset_uri)
else:
if row.object_store_type == ObjectStoreType.REMOTE:
_store = ObjectStore.from_data_link_uri(
row.data_link, row.auth_name
)
_store = ObjectStore.from_data_link_uri(row.data_link)
else:
_store = ObjectStore.from_dataset_uri(self.dataset_uri)

@@ -261,7 +253,7 @@ def _unpack_row(

store = self._get_store(row)
key_compose = self._get_key_compose(row, store)
file = store.backend._make_file(store.bucket, key_compose)
file = store.backend._make_file(key_compose=key_compose, bucket=store.bucket)
data_content, _ = self._read_data(file, row)
data = BaseArtifact.reflect(data_content, row.data_type)
return DataRow(index=row.id, data=data, annotations=row.annotations)
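The rewritten `_get_store` keys its `ObjectStore` cache on the endpoint-and-bucket portion of each row's data link instead of the old `auth_name`, so rows that share a bucket reuse one store. A minimal standalone sketch of that key derivation (pure Python; the dataset URI value here is made up):

from urllib.parse import urlparse

def store_cache_key(dataset_uri: str, data_link_uri: str) -> str:
    # Mirror _get_store: strip the object key (everything after the bucket)
    # so the cache key keeps only scheme://endpoint/bucket/.
    parts = urlparse(data_link_uri).path.lstrip("/").split("/", 1)
    return f"{dataset_uri}.{data_link_uri.replace(parts[-1], '')}"

print(store_cache_key(
    "local/project/self/dataset/mnist/version/latest",   # made-up dataset URI
    "s3://10.0.0.1:9000/bucket/speech/train/0.wav",
))
# -> local/project/self/dataset/mnist/version/latest.s3://10.0.0.1:9000/bucket/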
2 changes: 2 additions & 0 deletions client/starwhale/cli/__init__.py
@@ -15,6 +15,7 @@
from starwhale.core.runtime.cli import runtime_cmd
from starwhale.core.instance.cli import instance_cmd

from .cli import config_cmd
from .mngt import add_mngt_command
from .completion import completion_cmd

@@ -46,6 +47,7 @@ def cli(ctx: click.Context, verbose: bool, output: str) -> None:
cli.add_command(dataset_cmd, aliases=["ds"]) # type: ignore
cli.add_command(open_board)
cli.add_command(completion_cmd)
cli.add_command(config_cmd)
add_mngt_command(cli)

return cli
18 changes: 18 additions & 0 deletions client/starwhale/cli/cli.py
@@ -0,0 +1,18 @@
import click

from starwhale.utils import config
from starwhale.utils.cli import AliasedGroup


@click.group(
"config",
cls=AliasedGroup,
help="Configuration management, edit is supported now",
)
def config_cmd() -> None:
pass


@config_cmd.command("edit", aliases=["e"], help="edit the configuration of swcli")
def __edit() -> None:
config.edit_from_shell()
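With `config_cmd` registered on the main CLI in `cli/__init__.py` above, the new flow is presumably invoked as `swcli config edit` (or via the `e` alias), and `config.edit_from_shell()` presumably opens the swcli config file in an editor. A quick way to exercise the group in isolation, assuming a checkout of this commit with click installed:

from click.testing import CliRunner

from starwhale.cli.cli import config_cmd

# Render the group's help text in-process; the "edit" subcommand should be listed.
result = CliRunner().invoke(config_cmd, ["--help"])
print(result.output)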
1 change: 0 additions & 1 deletion client/starwhale/consts/__init__.py
@@ -181,4 +181,3 @@ class FileNode:
DEFAULT_CONDA_CHANNEL = "conda-forge"

WHEEL_FILE_EXTENSION = ".whl"
AUTH_ENV_FNAME = ".auth_env"
118 changes: 79 additions & 39 deletions client/starwhale/core/dataset/store.py
@@ -41,7 +41,7 @@
FieldTypeOrValueError,
)
from starwhale.utils.retry import http_retry
from starwhale.utils.config import SWCliConfigMixed
from starwhale.utils.config import SWCliConfigMixed, get_swcli_config_path
from starwhale.core.dataset.type import Link

# TODO: refactor Dataset and ModelPackage LocalStorage
@@ -187,21 +187,29 @@ def close(self) -> None:


class S3Connection:
connections_config: t.List[S3Connection] = []
init_config_lock = threading.Lock()
supported_schemes = {"s3", "minio", "aliyun", "oss"}
DEFAULT_CONNECT_TIMEOUT = 10.0
DEFAULT_READ_TIMEOUT = 50.0
DEFAULT_MAX_ATTEMPTS = 6

def __init__(
self,
endpoint: str,
access_key: str,
secret_key: str,
region: str = "",
bucket: str = "",
connect_timeout: float = 10.0,
read_timeout: float = 50.0,
total_max_attempts: int = 6,
connect_timeout: float = DEFAULT_CONNECT_TIMEOUT,
read_timeout: float = DEFAULT_READ_TIMEOUT,
total_max_attempts: int = DEFAULT_MAX_ATTEMPTS,
) -> None:
self.endpoint = endpoint.strip()
if self.endpoint and not self.endpoint.startswith(("http://", "https://")):
self.endpoint = f"http://{self.endpoint}"

r = urlparse(self.endpoint)
self.endpoint_loc = r.netloc.split("@")[-1]
self.access_key = access_key
self.secret_key = secret_key
self.region = region
@@ -218,55 +226,88 @@ def __init__(
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
self.extra_s3_configs = json.loads(os.environ.get("SW_S3_EXTRA_CONFIGS", "{}"))

def fits(self, bucket: str, endpoint_loc: str) -> bool:
return self.bucket == bucket and self.endpoint_loc == endpoint_loc

def __str__(self) -> str:
return f"endpoint[{self.endpoint}]-region[{self.region}]"

__repr__ = __str__

@classmethod
def from_uri(cls, uri: str, auth_name: str = "") -> S3Connection:
def from_uri(cls, uri: str) -> S3Connection:
"""make S3 Connection by uri
uri:
- s3://username:password@127.0.0.1:8000/bucket/key
- s3://127.0.0.1:8000/bucket/key
"""
from .type import S3LinkAuth

uri = uri.strip()
if not uri or not uri.startswith(("s3://", "minio://")):
r = urlparse(uri.strip())
if r.scheme not in cls.supported_schemes:
raise NoSupportError(
f"s3 connection only support s3:// prefix, the actual uri is {uri}"
f"s3 connection only support {cls.supported_schemes} prefix, the actual uri is {uri}"
)

r = urlparse(uri)

link_auth = S3LinkAuth.from_env(auth_name)
access = r.username or link_auth.access_key
secret = r.password or link_auth.secret
region = link_auth.region
endpoint = r.netloc.split("@")[-1] or link_auth.endpoint

endpoint_loc = r.netloc.split("@")[-1]
parts = r.path.lstrip("/").split("/", 1)
if len(parts) != 2 or parts[0] == "" or parts[1] == "":
raise FieldTypeOrValueError(
f"{uri} is not a valid s3 uri for bucket and key"
)
bucket = parts[0]
cls.build_init_connections()
for connection in S3Connection.connections_config:
if connection.fits(bucket, endpoint_loc):
return connection
if r.username and r.password:
return cls(
endpoint=endpoint_loc,
access_key=r.username,
secret_key=r.password,
region=_DEFAULT_S3_REGION,
bucket=bucket,
)
raise NoSupportError(f"no matching s3 config in {get_swcli_config_path()}")

if not endpoint:
raise FieldTypeOrValueError("endpoint is empty")

if not access or not secret:
raise FieldTypeOrValueError("no access_key or secret_key")

return cls(
endpoint=endpoint,
access_key=access,
secret_key=secret,
region=region or _DEFAULT_S3_REGION,
bucket=bucket,
)
@classmethod
def build_init_connections(cls) -> None:
with S3Connection.init_config_lock:
if S3Connection.connections_config:
return
sw_config = SWCliConfigMixed()
link_auths = sw_config.link_auths
if not link_auths:
return
for la in link_auths:
if not la.get("type") or type(la.get("type")) != str:
continue
if la.get("type") not in cls.supported_schemes:
continue
if (
not la.get("endpoint")
or not la.get("ak")
or not la.get("sk")
or not la.get("bucket")
):
continue
S3Connection.connections_config.append(
cls(
endpoint=la.get("endpoint"),
access_key=la.get("ak"),
secret_key=la.get("sk"),
region=la.get("region") or _DEFAULT_S3_REGION,
bucket=la.get("bucket"),
connect_timeout=la.get(
"connect_timeout", cls.DEFAULT_CONNECT_TIMEOUT
),
read_timeout=la.get("read_timeout", cls.DEFAULT_READ_TIMEOUT),
total_max_attempts=la.get(
"total_max_attempts", cls.DEFAULT_MAX_ATTEMPTS
),
)
)
env_conn = cls.from_env()
if env_conn:
S3Connection.connections_config.append(env_conn)

@classmethod
def from_env(cls) -> S3Connection:
@@ -312,14 +353,14 @@ def __repr__(self) -> str:
return f"ObjectStored:{self.backend}, bucket:{self.bucket}, key_prefix:{self.key_prefix}"

@classmethod
def from_data_link_uri(cls, data_link: Link, auth_name: str) -> ObjectStore:
def from_data_link_uri(cls, data_link: Link) -> ObjectStore:
if not data_link:
raise FieldTypeOrValueError("data_link is empty")

# TODO: support other uri type
if data_link.scheme in ("s3", "minio", "oss", "aliyun"):
if data_link.scheme in S3Connection.supported_schemes:
backend = SWDSBackendType.S3
conn = S3Connection.from_uri(data_link.uri, auth_name)
conn = S3Connection.from_uri(data_link.uri)
bucket = conn.bucket
elif data_link.scheme in ["http", "https"]:
backend = SWDSBackendType.Http
@@ -368,7 +409,6 @@ def _make_file(


class S3StorageBackend(StorageBackend):

lock_s3_creation = threading.Lock()

def __init__(
@@ -438,7 +478,7 @@ def __init__(self, dataset_uri: URI) -> None:

@http_retry
def _make_file(
self, auth: str, key_compose: t.Tuple[Link, int, int]
self, key_compose: t.Tuple[Link, int, int], **kwargs: t.Any
) -> FileLikeObj:
_key, _start, _end = key_compose
return HttpBufferedFileLike(
@@ -469,7 +509,7 @@ def __init__(self) -> None:

@http_retry
def _make_file(
self, auth: str, key_compose: t.Tuple[Link, int, int]
self, key_compose: t.Tuple[Link, int, int], **kwargs: t.Any
) -> FileLikeObj:
_key, _start, _end = key_compose
return HttpBufferedFileLike(
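`build_init_connections` turns each well-formed `link_auths` entry of the swcli config into an `S3Connection`, and `from_uri` then selects the one whose bucket and endpoint location match the link being read, falling back to credentials embedded in the URI itself. A standalone sketch of that matching, using hypothetical entries shaped like the dicts the code above consumes (the on-disk layout of the swcli config file is not shown in this commit):

from typing import Optional
from urllib.parse import urlparse

# Hypothetical link_auths entries; keys mirror what build_init_connections reads.
link_auths = [
    {"type": "s3", "endpoint": "http://10.0.0.1:9000", "ak": "AK1", "sk": "SK1", "bucket": "speech"},
    {"type": "minio", "endpoint": "10.0.0.2:9000", "ak": "AK2", "sk": "SK2", "bucket": "images"},
]

def endpoint_loc(endpoint: str) -> str:
    # Mirror S3Connection.__init__: default to http:// and keep only host[:port].
    if not endpoint.startswith(("http://", "https://")):
        endpoint = f"http://{endpoint}"
    return urlparse(endpoint).netloc.split("@")[-1]

def find_entry(uri: str) -> Optional[dict]:
    # Mirror S3Connection.from_uri + fits: match on bucket plus endpoint location.
    r = urlparse(uri)
    bucket = r.path.lstrip("/").split("/", 1)[0]
    loc = r.netloc.split("@")[-1]
    for entry in link_auths:
        if entry["bucket"] == bucket and endpoint_loc(entry["endpoint"]) == loc:
            return entry
    return None

print(find_entry("s3://10.0.0.1:9000/speech/train/0.wav"))  # -> the "speech" entry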
(diffs for the remaining changed files not shown)