Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug(datastore): fix scan table path argument error #908

Merged
merged 1 commit into from
Aug 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pyarrow as pa # type: ignore
import pyarrow.parquet as pq # type: ignore
from loguru import logger
from typing_extensions import Protocol

from starwhale.utils.fs import ensure_dir
Expand Down Expand Up @@ -317,7 +318,9 @@ def nextItem(self) -> None:

def _get_table_files(path: str) -> List[str]:
if not os.path.exists(path):
logger.warning(f"not find path {path} as table file path")
return []

if not os.path.isdir(path):
raise RuntimeError(f"{path} is not a directory")

Expand Down Expand Up @@ -640,17 +643,20 @@ def __init__(
key_column_type: pa.DataType,
columns: Optional[Dict[str, str]],
explicit_none: bool,
path: str,
) -> None:
self.name = name
self.key_column_type = key_column_type
self.columns = columns
self.explicit_none = explicit_none
self.path = path

infos: List[TableInfo] = []
for table_name, table_alias, explicit_none in tables:
table_path = _get_table_path(self.root_path, table_name)
table = self.tables.get(table_name, None)
if table is None:
schema = _read_table_schema(_get_table_path(self.root_path, table_name))
schema = _read_table_schema(table_path)
else:
schema = table.get_schema()
key_column_type = schema.columns[schema.key_column].type.pa_type
Expand All @@ -667,7 +673,9 @@ def __init__(
alias = columns.get(col_prefix + name, alias)
if alias != "":
cols[name] = alias
infos.append(TableInfo(table_name, key_column_type, cols, explicit_none))
infos.append(
TableInfo(table_name, key_column_type, cols, explicit_none, table_path)
)

# check for key type conflictions
for info in infos:
Expand All @@ -689,7 +697,7 @@ def __init__(
)
else:
iters.append(
_scan_table(info.name, info.columns, start, end, info.explicit_none)
_scan_table(info.path, info.columns, start, end, info.explicit_none)
)

for record in _merge_scan(iters):
Expand Down