1 change: 1 addition & 0 deletions pardata/__init__.py
@@ -25,5 +25,6 @@
                          init,
                          list_all_datasets,
                          load_dataset,
                          load_dataset_from_location,
                          load_schema_collections)
from ._version import version as __version__
69 changes: 68 additions & 1 deletion pardata/_high_level.py
@@ -1,5 +1,5 @@
#
# Copyright 2020 IBM Corp. All Rights Reserved.
# Copyright 2020--2021 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,15 +23,19 @@
from copy import deepcopy
import dataclasses
import functools
import hashlib
from textwrap import dedent
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar, Union, cast
import os
from packaging.version import parse as version_parser
import re

from ._config import Config
from ._dataset import Dataset
from . import typing as typing_
from ._schema import (DatasetSchemaCollection, FormatSchemaCollection, LicenseSchemaCollection,
SchemaDict, SchemaCollectionManager)
from ._schema_retrieval import is_url

# Global configurations --------------------------------------------------

@@ -208,6 +212,69 @@ def load_dataset(name: str, *,
f'\nCaused by:\n{e}')


def load_dataset_from_location(url_or_path: Union[str, typing_.PathLike], *,
                               schema: Optional[SchemaDict] = None,
                               force_redownload: bool = False) -> Dict[str, Any]:
    """Load the dataset from ``url_or_path``. This function is equivalent to constructing a :class:`~pardata.Dataset`
    with ``schema['download_url']`` set to ``url_or_path``. In the returned :class:`dict` object, keys corresponding
    to empty values are removed (unlike :meth:`~pardata.Dataset.load`).

    :param url_or_path: The URL or path of the dataset archive.
    :param schema: The schema used for loading the dataset. If ``None``, a default schema designed to accommodate the
        most common use cases is used.
    :param force_redownload: If ``True``, force redownloading the dataset even if it has been downloaded before.
    :return: A dictionary that holds the dataset, structured the same as the return value of :func:`load_dataset`.
    """

    if not is_url(str(url_or_path)):
        url_or_path = os.path.abspath(url_or_path)  # Don't use pathlib.Path.resolve because it resolves symlinks
    url_or_path = cast(str, url_or_path)

    # Name of the data dir: {url_or_path with non-alphanumeric runs replaced by dashes}-sha512. The sha512 suffix is
    # there to prevent collisions.
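    # E.g., 'https://example.com/a.zip' maps to 'https-example-com-a-zip-<128-hex-digit sha512>'.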
    data_dir_name = (f'{re.sub("[^0-9a-zA-Z]+", "-", url_or_path)}-'
                     f'{hashlib.sha512(url_or_path.encode("utf-8")).hexdigest()}')
    data_dir = get_config().DATADIR / '_location_direct' / data_dir_name
    if schema is None:
        # Construct the default schema
        schema = {
            'name': 'Direct from a location',
            'description': 'Loaded directly from a location',
            'subdatasets': {}
        }

        RegexFormatPair = namedtuple('RegexFormatPair', ['regex', 'format'])
        regex_format_pairs = (
            RegexFormatPair(regex=r'.*\.csv', format='table/csv'),
            RegexFormatPair(regex=r'.*\.wav', format='audio/wav'),
            RegexFormatPair(regex=r'.*\.(txt|log)', format='text/plain'),
            RegexFormatPair(regex=r'.*\.(jpg|jpeg)', format='image/jpeg'),
            RegexFormatPair(regex=r'.*\.png', format='image/png'),
        )

        for regex_format_pair in regex_format_pairs:
            schema['subdatasets'][regex_format_pair.format] = {
                'format': {
                    'id': regex_format_pair.format,
                },
                'path': {
                    'type': 'regex',
                    'value': regex_format_pair.regex
                }
            }
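    # Point the schema (default or user-supplied) at the given location; `Dataset` downloads from
    # `schema['download_url']`.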
    schema['download_url'] = url_or_path

    dataset = Dataset(schema=schema, data_dir=data_dir, mode=Dataset.InitializationMode.LAZY)
    if force_redownload or not dataset.is_downloaded():
        dataset.download(check=False,  # Already checked by the `is_downloaded` call above
                         verify_checksum=False)
    dataset.load()

    # Strip empty values
    return {k: v for k, v in dataset.data.items() if len(v) > 0}


@_handle_name_param
@_handle_version_param
def get_dataset_metadata(name: str, *, version: str = 'latest') -> SchemaDict:
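For reviewers who want to try the new API, here is a minimal usage sketch (the archive URL below is hypothetical; everything else follows the function above). With the default schema, the returned dict maps each inferred format ID to a dict of file paths and their loaded contents:

    import pardata

    pardata.init()  # set up the global config, including DATADIR

    # The first call downloads and caches the archive; later calls load from the cache
    # unless force_redownload=True is passed.
    data = pardata.load_dataset_from_location('https://example.com/my-dataset.zip')

    for format_id, files in data.items():    # e.g. 'text/plain', 'table/csv'
        for path, content in files.items():  # path inside the extracted archive
            print(format_id, path, type(content))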
32 changes: 31 additions & 1 deletion tests/test_high_level.py
@@ -24,7 +24,7 @@
from pydantic import ValidationError

from pardata import (describe_dataset, export_schema_collections, get_config, get_dataset_metadata, init,
                     list_all_datasets, load_dataset, load_schema_collections)
                     list_all_datasets, load_dataset, load_dataset_from_location, load_schema_collections)
from pardata.dataset import Dataset
from pardata._config import Config
from pardata._high_level import _get_schema_collections
@@ -204,6 +204,36 @@ def test_loading_undownloaded(self, tmp_path):
'(by calling this function with `download=True` for at least once)?') in str(e.value)


class TestLoadDatasetFromLocation:
    "Test ``load_dataset_from_location``."

    def test_loading_dataset_from_path(self, downloaded_gmb_dataset, dataset_dir):
        "Test loading from a local path: first downloading, then loading from cache, then forcing a redownload."
        for force_redownload in (False, False, True):
            data = load_dataset_from_location(dataset_dir / 'gmb-1.0.2.zip', force_redownload=force_redownload)
            assert frozenset(data.keys()) == frozenset(('text/plain',))
            assert frozenset(data['text/plain'].keys()) == frozenset((
                'groningen_meaning_bank_modified/gmb_subset_full.txt',
                'groningen_meaning_bank_modified/LICENSE.txt',
                'groningen_meaning_bank_modified/README.txt'
            ))

    def test_loading_dataset_from_url(self, gmb_schema):
        "Test loading from a URL: first downloading, then loading from cache, then forcing a redownload."
        for force_redownload in (False, False, True):
            data = load_dataset_from_location(gmb_schema['download_url'], force_redownload=force_redownload)
            assert frozenset(data.keys()) == frozenset(('text/plain',))
            assert frozenset(data['text/plain'].keys()) == frozenset((
                'groningen_meaning_bank_modified/gmb_subset_full.txt',
                'groningen_meaning_bank_modified/LICENSE.txt',
                'groningen_meaning_bank_modified/README.txt'
            ))

    def test_custom_schema(self, gmb_schema):
        "Test loading with a user-supplied schema instead of the default one."
        data = load_dataset_from_location(gmb_schema['download_url'], schema=gmb_schema)
        assert frozenset(data.keys()) == frozenset(('gmb_subset_full',))
        assert data['gmb_subset_full'].startswith('Masked VBN O\n')
        assert data['gmb_subset_full'].endswith('. . O\n\n')


def test_get_dataset_metadata():
"Test ``get_dataset_metadata``."
