Lazy arff (#1346)
* Prefer parquet over arff, do not load arff if not needed

* Only download arff if needed

* Test arff file is not set when downloading parquet from prod
PGijsbers authored Sep 16, 2024
1 parent fa7e9db commit b4d038f
Showing 3 changed files with 33 additions and 19 deletions.
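In user-facing terms, a dataset whose parquet copy can be downloaded no longer touches the ARFF endpoint at all, and its `data_file` attribute (the ARFF path) stays `None`. A minimal sketch of the new behavior, assuming a dataset that has a parquet copy on the production server (id 61, iris, used here as an assumed example):

import os

import openml

# Parquet is now preferred; ARFF is only fetched as a fallback.
dataset = openml.datasets.get_dataset(61, download_data=True)

print(dataset.parquet_file)                  # path to the cached .pq file
print(os.path.isfile(dataset.parquet_file))  # True
print(dataset.data_file)                     # None: the ARFF copy was never downloaded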
40 changes: 25 additions & 15 deletions openml/datasets/dataset.py
@@ -345,9 +345,10 @@ def _download_data(self) -> None:
         # import required here to avoid circular import.
         from .functions import _get_dataset_arff, _get_dataset_parquet
 
-        self.data_file = str(_get_dataset_arff(self))
         if self._parquet_url is not None:
             self.parquet_file = str(_get_dataset_parquet(self))
+        if self.parquet_file is None:
+            self.data_file = str(_get_dataset_arff(self))
 
     def _get_arff(self, format: str) -> dict:  # noqa: A002
         """Read ARFF file and return decoded arff.
@@ -535,18 +536,7 @@ def _cache_compressed_file_from_file(
             feather_attribute_file,
         ) = self._compressed_cache_file_paths(data_file)
 
-        if data_file.suffix == ".arff":
-            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
-        elif data_file.suffix == ".pq":
-            try:
-                data = pd.read_parquet(data_file)
-            except Exception as e:  # noqa: BLE001
-                raise Exception(f"File: {data_file}") from e
-
-            categorical = [data[c].dtype.name == "category" for c in data.columns]
-            attribute_names = list(data.columns)
-        else:
-            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        attribute_names, categorical, data = self._parse_data_from_file(data_file)
 
         # Feather format does not work for sparse datasets, so we use pickle for sparse datasets
         if scipy.sparse.issparse(data):
@@ -572,6 +562,24 @@
 
         return data, categorical, attribute_names
 
+    def _parse_data_from_file(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        if data_file.suffix == ".arff":
+            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
+        elif data_file.suffix == ".pq":
+            attribute_names, categorical, data = self._parse_data_from_pq(data_file)
+        else:
+            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        return attribute_names, categorical, data
+
+    def _parse_data_from_pq(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        try:
+            data = pd.read_parquet(data_file)
+        except Exception as e:  # noqa: BLE001
+            raise Exception(f"File: {data_file}") from e
+        categorical = [data[c].dtype.name == "category" for c in data.columns]
+        attribute_names = list(data.columns)
+        return attribute_names, categorical, data
+
     def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool], list[str]]:  # noqa: PLR0912, C901
         """Load data from compressed format or arff. Download data if not present on disk."""
         need_to_create_pickle = self.cache_format == "pickle" and self.data_pickle_file is None
@@ -636,8 +644,10 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool]
                 "Please manually delete the cache file if you want OpenML-Python "
                 "to attempt to reconstruct it.",
             )
-            assert self.data_file is not None
-            data, categorical, attribute_names = self._parse_data_from_arff(Path(self.data_file))
+            file_to_load = self.data_file if self.parquet_file is None else self.parquet_file
+            assert file_to_load is not None
+            attr, cat, df = self._parse_data_from_file(Path(file_to_load))
+            return df, cat, attr
 
         data_up_to_date = isinstance(data, pd.DataFrame) or scipy.sparse.issparse(data)
         if self.cache_format == "pickle" and not data_up_to_date:
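The inline suffix branches that previously lived in `_cache_compressed_file_from_file` now sit behind the `_parse_data_from_file` dispatcher and a dedicated parquet helper. A self-contained sketch of the same dispatch outside the class, with the ARFF branch stubbed since `_parse_data_from_arff` lives elsewhere in the module:

from __future__ import annotations

from pathlib import Path

import pandas as pd


def parse_data_from_file(data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
    # Mirrors OpenMLDataset._parse_data_from_file: dispatch on the file suffix.
    if data_file.suffix == ".arff":
        raise NotImplementedError("handled by _parse_data_from_arff in the real class")
    if data_file.suffix == ".pq":
        try:
            data = pd.read_parquet(data_file)
        except Exception as e:
            # Re-raise with the offending path so a corrupt cache file is easy to locate.
            raise Exception(f"File: {data_file}") from e
        categorical = [data[c].dtype.name == "category" for c in data.columns]
        return list(data.columns), categorical, data
    raise ValueError(f"Unknown file type for file '{data_file}'.")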
11 changes: 7 additions & 4 deletions openml/datasets/functions.py
@@ -450,7 +450,7 @@ def get_datasets(
 
 
 @openml.utils.thread_safe_if_oslo_installed
-def get_dataset(  # noqa: C901, PLR0912
+def get_dataset(  # noqa: C901, PLR0912, PLR0915
     dataset_id: int | str,
     download_data: bool | None = None,  # Optional for deprecation warning; later again only bool
     version: int | None = None,
@@ -589,7 +589,6 @@ def get_dataset(  # noqa: C901, PLR0912
         if download_qualities:
             qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id)
 
-        arff_file = _get_dataset_arff(description) if download_data else None
         if "oml:parquet_url" in description and download_data:
             try:
                 parquet_file = _get_dataset_parquet(
@@ ... @@
                 )
             except urllib3.exceptions.MaxRetryError:
                 parquet_file = None
-            if parquet_file is None and arff_file:
-                logger.warning("Failed to download parquet, fallback on ARFF.")
         else:
             parquet_file = None
 
+        arff_file = None
+        if parquet_file is None and download_data:
+            logger.warning("Failed to download parquet, fallback on ARFF.")
+            arff_file = _get_dataset_arff(description)
+
         remove_dataset_cache = False
     except OpenMLServerException as e:
         # if there was an exception
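The net effect in `get_dataset` is that the ARFF download moves from unconditional to a last resort. A simplified sketch of the new ordering; the two download helpers are hypothetical stand-ins for `_get_dataset_parquet` (with its MaxRetryError handling folded in) and `_get_dataset_arff`:

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def try_download_parquet(description: dict) -> str | None:
    """Hypothetical stand-in for _get_dataset_parquet; returns None on failure."""
    ...


def download_arff(description: dict) -> str:
    """Hypothetical stand-in for _get_dataset_arff."""
    ...


def fetch_data_files(description: dict, download_data: bool) -> tuple[str | None, str | None]:
    # Parquet first, and only when the description advertises a parquet URL.
    parquet_file = None
    if "oml:parquet_url" in description and download_data:
        parquet_file = try_download_parquet(description)

    # ARFF is now lazy: fetched only when no parquet file was obtained.
    arff_file = None
    if parquet_file is None and download_data:
        logger.warning("Failed to download parquet, fallback on ARFF.")
        arff_file = download_arff(description)
    return parquet_file, arff_file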
1 change: 1 addition & 0 deletions tests/test_datasets/test_dataset_functions.py
@@ -1574,6 +1574,7 @@ def test_get_dataset_parquet(self):
         assert dataset._parquet_url is not None
         assert dataset.parquet_file is not None
         assert os.path.isfile(dataset.parquet_file)
+        assert dataset.data_file is None  # is alias for arff path
 
     @pytest.mark.production()
     def test_list_datasets_with_high_size_parameter(self):
