local datasets #175

Open
wants to merge 14 commits into master
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
@@ -51,6 +51,6 @@ jobs:
      shell: cmd
      run: |
        pip install pytest
        pytest test\util.py test\metadata.py test\integration\dummy.py test\integration\vaswani.py test\formats\
        pytest test\util.py test\metadata.py test\integration\dummy.py test\integration\vaswani.py test\formats\ test\test_local.py
      env:
        PATH: 'C:/Program Files/zlib/bin/'
16 changes: 11 additions & 5 deletions ir_datasets/__init__.py
@@ -6,20 +6,26 @@ class EntityType(Enum):
    scoreddocs = "scoreddocs"
    docpairs = "docpairs"
    qlogs = "qlogs"

from . import util
registry = util.Registry()

def load(name):
    return registry[name]

from . import lazy_libs
from . import log
from . import util
from . import formats
registry = util.Registry()
from . import datasets
from . import indices
from . import wrappers
from . import commands

Dataset = datasets.base.Dataset
create_local_dataset = datasets.create_local_dataset
delete_local_dataset = datasets.delete_local_dataset
iter_local_datasets = datasets.iter_local_datasets

def load(name):
    return registry[name]
Dataset = datasets.base.Dataset


def parent_id(dataset_id: str, entity_type: EntityType) -> str:
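Taken together with datasets/local.py below, these new top-level exports suggest usage along the following lines. A minimal sketch, with the dataset id and records invented for illustration (docs are materialized from dicts, while the queries handler is borrowed from an existing dataset by id):

import ir_datasets

# Illustrative only: 'local-example' and the records below are not part of this PR.
ds = ir_datasets.create_local_dataset(
    'local-example',
    docs=[
        {'doc_id': 'd1', 'text': 'hello world'},
        {'doc_id': 'd2', 'text': 'another document'},
    ],
    queries='vaswani',  # a string reuses that dataset's queries handler
)

# The dataset is registered, so it can also be re-loaded by id afterwards.
for doc in ir_datasets.load('local-example').docs_iter():
    print(doc.doc_id, doc.text)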
2 changes: 2 additions & 0 deletions ir_datasets/datasets/__init__.py
@@ -44,5 +44,7 @@
from . import wapo
from . import wikir
from . import trec_fair_2021

from .local import iter_local_datasets, create_local_dataset, delete_local_dataset
from . import trec_cast # must be after wapo,car,msmarco_passage
from . import hc4
18 changes: 18 additions & 0 deletions ir_datasets/datasets/base.py
@@ -97,6 +97,19 @@ def has_docpairs(self):
    def has_qlogs(self):
        return self.has(ir_datasets.EntityType.qlogs)

    def handler(self, etype: ir_datasets.EntityType):
        etype = ir_datasets.EntityType(etype) # validate & allow strings
        return getattr(self, f'{etype.value}_handler')()

    def clear_cache(self):
        for bapi in self._beta_apis.values():
            if hasattr(bapi, 'clear_cache'):
                bapi.clear_cache()
        self._beta_apis.clear()
        for c in self._constituents:
            if hasattr(c, 'clear_cache'):
                c.clear_cache()


class _BetaPythonApiDocs:
    def __init__(self, handler):
@@ -136,6 +149,11 @@ def lookup_iter(self, doc_ids):
    def metadata(self):
        return self._handler.docs_metadata()

    def clear_cache(self):
        if self._docstore is not None and hasattr(self._docstore, 'clear_cache'):
            self._docstore.clear_cache()
        self._docstore = None


class _BetaPythonApiQueries:
    def __init__(self, handler):
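The new Dataset.handler() and clear_cache() hooks are what local.py relies on when borrowing an entity handler from an existing dataset and when deleting a local dataset. A small sketch of handler() as implemented above (the dataset id is chosen only for illustration):

import ir_datasets
from ir_datasets import EntityType

ds = ir_datasets.load('vaswani')
# handler() coerces strings to EntityType, so both calls resolve to docs_handler().
docs_handler = ds.handler('docs')
docs_handler = ds.handler(EntityType.docs)
# An unknown entity name raises ValueError from the EntityType() coercion.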
296 changes: 296 additions & 0 deletions ir_datasets/datasets/local.py
@@ -0,0 +1,296 @@
from .base import Dataset
import itertools
import pickle
import json
import shutil
import uuid
from typing import NamedTuple
import ir_datasets
from ir_datasets import EntityType
from ir_datasets.indices import Lz4PickleLookup, PickleLz4FullStore
from ir_datasets.formats import BaseDocs, BaseQueries, BaseQrels, BaseScoredDocs, BaseDocPairs, BaseQlogs

logger = ir_datasets.log.easy()

NAME = 'local'

BASE_PATH = ir_datasets.util.home_path()/NAME


class BaseLocal:
    def __init__(self, path, etype):
        self._path = path
        self._etype = etype
        self._cls_ = None

    def _iter(self):
        cls = self._cls()
        lz4_frame = ir_datasets.lazy_libs.lz4_frame().frame
        with lz4_frame.LZ4FrameFile(self._path/'data', 'rb') as fin:
            while fin.peek(1):
                yield cls(*pickle.load(fin))

    def _count(self):
        with (self._path/'count.pkl').open('rb') as fin:
            return pickle.load(fin)

    def _cls(self):
        if self._cls_ is None:
            with (self._path/'type.pkl').open('rb') as fin:
                name, attrs = pickle.load(fin)
                self._cls_ = NamedTuple(name, list(attrs.items()))
        return self._cls_

    @classmethod
    def create(cls, path, records):
        path.mkdir(exist_ok=True, parents=True)
        records = iter(records)
        first = next(records)
        if isinstance(first, dict):
            EntityCls = NamedTuple('LocalEntity', [(k, type(v)) for k, v in first.items()])
            records = itertools.chain([EntityCls(**first)], (EntityCls(**r) for r in records))
        else:
            EntityCls = type(first)
            records = itertools.chain([first], records)
        with ir_datasets.util.finialized_file(path/'type.pkl', 'wb') as fout:
            pickle.dump((EntityCls.__name__, EntityCls.__annotations__), fout)
        count = cls._build_datafile(path/'data', records, EntityCls)
        with ir_datasets.util.finialized_file(path/'count.pkl', 'wb') as fout:
            pickle.dump(count, fout)

    @classmethod
    def _build_datafile(cls, path, records, ecls):
        count = 0
        lz4_frame = ir_datasets.lazy_libs.lz4_frame().frame
        with ir_datasets.util.finialized_file(path, 'wb') as raw_fout, \
                lz4_frame.LZ4FrameFile(raw_fout, 'wb') as fout:
            for record in records:
                pickle.dump(tuple(record), fout)
                count += 1
        return count


class LocalDocs(BaseLocal, BaseDocs):
    def __init__(self, path):
        super().__init__(path, EntityType.docs)

    def docs_iter(self):
        return iter(self.docs_store())

    def docs_store(self, field='doc_id'):
        return PickleLz4FullStore(
            path=f'{self._path}/data',
            init_iter_fn=None,
            data_cls=self._cls(),
            lookup_field='doc_id',
            index_fields=['doc_id'],
            count_hint=self._count,
        )

    def docs_count(self):
        return len(self.docs_store())

    def docs_cls(self):
        return self._cls()

    @classmethod
    def _build_datafile(cls, path, records, ecls):
        lookup = Lz4PickleLookup(path, ecls, 'doc_id', ['doc_id'])
        with lookup.transaction() as trans:
            for doc in records:
                trans.add(doc)


class LocalQueries(BaseLocal, BaseQueries):
    def __init__(self, path):
        super().__init__(path, EntityType.queries)

    def queries_iter(self):
        return self._iter()

    def queries_count(self):
        return self._count()

    def queries_cls(self):
        return self._cls()


class LocalQrels(BaseLocal, BaseQrels):
    def __init__(self, path):
        super().__init__(path, EntityType.qrels)

    def qrels_iter(self):
        return self._iter()

    def qrels_count(self):
        return self._count()

    def qrels_cls(self):
        return self._cls()

    def qrels_defs(self):
        return {}


class LocalScoredDocs(BaseLocal, BaseScoredDocs):
    def __init__(self, path):
        super().__init__(path, EntityType.scoreddocs)

    def scoreddocs_iter(self):
        return self._iter()

    def scoreddocs_count(self):
        return self._count()

    def scoreddocs_cls(self):
        return self._cls()


class LocalDocpairs(BaseLocal, BaseDocPairs):
    def __init__(self, path):
        super().__init__(path, EntityType.docpairs)

    def docpairs_iter(self):
        return self._iter()

    def docpairs_count(self):
        return self._count()

    def docpairs_cls(self):
        return self._cls()


class LocalQlogs(BaseLocal, BaseQlogs):
    def __init__(self, path):
        super().__init__(path, EntityType.qlogs)

    def qlogs_iter(self):
        return self._iter()

    def qlogs_count(self):
        return self._count()

    def qlogs_cls(self):
        return self._cls()


PROVIDERS = {
    EntityType.docs: LocalDocs,
    EntityType.queries: LocalQueries,
    EntityType.qrels: LocalQrels,
    EntityType.scoreddocs: LocalScoredDocs,
    EntityType.docpairs: LocalDocpairs,
    EntityType.qlogs: LocalQlogs,
}


def create_local_dataset(dataset_id, **sources):
    if dataset_id in ir_datasets.registry:
        raise KeyError(f'{dataset_id} already in registry; choose another name')
    path = str(uuid.uuid4())
    ds_path = BASE_PATH/path
    components = []
    dataset_record = {
        'id': dataset_id,
        'path': path,
        'provides': {},
    }
    with logger.duration(f'provisioning {dataset_id} to {ds_path}'):
        ds_path.mkdir(exist_ok=True, parents=True)
        for etype in EntityType:
            if etype.value not in sources:
                continue
            source = sources[etype.value]
            if isinstance(source, str):
                components.append(ir_datasets.load(source).handler(etype))
                dataset_record['provides'][etype.value] = source
            else:
                e_path = ds_path/etype.value
                provider_cls = PROVIDERS[EntityType(etype)]
                with logger.duration(f'creating {str(e_path)}'):
                    provider_cls.create(e_path, source)
                components.append(provider_cls(e_path))
                dataset_record['provides'][etype.value] = None
            del sources[etype.value]
    if sources:
        raise RuntimeError(f'Unexpected argument(s): {sources.keys()}')
    dataset = Dataset(*components)
    ir_datasets.registry.register(dataset_id, dataset)
    registry_path = (BASE_PATH/'registry.json')
    if registry_path.exists():
        with registry_path.open('rt') as fin:
            registry = json.load(fin)
    else:
        registry = []
    registry.append(dataset_record)
    with ir_datasets.util.finialized_file(registry_path, 'wt') as fout:
        fout.write('[\n')
        for item in registry:
            fout.write(' ' + json.dumps(item))
            if item != registry[-1]:
                fout.write(',')
            fout.write('\n')
        fout.write(']\n')
    return dataset


def delete_local_dataset(dataset_id, remove_files=True):
    registry_paths = list(BASE_PATH.glob('registry*.json'))
    for registry_path in sorted(registry_paths):
        with registry_path.open('rt') as fin:
            registry = json.load(fin)
        changed = False
        for dataset in list(registry):
            if dataset['id'] == dataset_id:
                registry.remove(dataset)
                changed = True
                try:
                    ds = ir_datasets.registry[dataset_id]
                    ds.clear_cache()
                    del ir_datasets.registry[dataset_id]
                    del ds # clean up this dataset (e.g., close open files)
                except KeyError:
                    pass
                if remove_files:
                    with logger.duration(f'Removing {dataset["id"]} at {str(BASE_PATH/dataset["path"])}'):
                        shutil.rmtree(BASE_PATH/dataset['path'])
        if changed:
            with ir_datasets.util.finialized_file(registry_path, 'wt') as fout:
                fout.write('[\n')
                for item in registry:
                    fout.write(' ' + json.dumps(item))
                    if item != registry[-1]:
                        fout.write(',')
                    fout.write('\n')
                fout.write(']\n')


def iter_local_datasets():
    registry_paths = list(BASE_PATH.glob('registry*.json'))
    for registry_path in sorted(registry_paths):
        with registry_path.open('rt') as fin:
            registry = json.load(fin)
        for dataset in list(registry):
            dataset['registry_path'] = str(registry_path)
            yield dataset


def _init():
    for dataset in iter_local_datasets():
        if dataset['id'] in ir_datasets.registry:
            new_id = f'local/{dataset["path"]}'
            logger.warn(f'Local dataset {repr(dataset["id"])} from {dataset["registry_path"]} already in registry. '
                        f'Renaming it to {new_id}.')
            dataset['id'] = new_id
        ds_path = BASE_PATH/dataset['path']
        components = []
        for etype, val in dataset['provides'].items():
            if val is None:
                provider_cls = PROVIDERS[EntityType(etype)]
                components.append(provider_cls(ds_path/etype))
            else:
                components.append(ir_datasets.load(val).handler(etype))
        ir_datasets.registry.register(dataset['id'], Dataset(*components))

_init()
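One detail worth noting in BaseLocal.create above: records may be plain dicts, in which case a LocalEntity NamedTuple class is synthesized from the first record's keys and value types, or instances of an existing NamedTuple class, in which case type(first) is stored instead. A hypothetical illustration of the two equivalent input styles (the ExampleQuery type below is not part of the PR):

from typing import NamedTuple

class ExampleQuery(NamedTuple):
    query_id: str
    text: str

# Either form can be passed as a source to create_local_dataset(..., queries=...);
# the first record determines the class pickled to type.pkl.
dict_records = [{'query_id': 'q1', 'text': 'measurement of dielectric constant'}]
typed_records = [ExampleQuery('q1', 'measurement of dielectric constant')]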
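For completeness, the bookkeeping implied by the functions above: data lands under ir_datasets.util.home_path()/'local'/<uuid> (by default ~/.ir_datasets/local/...), a registry.json list of {'id', 'path', 'provides'} records is rewritten on every create, _init() re-registers those datasets at import time, and delete_local_dataset both unregisters the dataset and optionally removes its files. A hedged end-to-end sketch (id and records are illustrative):

import json
import ir_datasets
from ir_datasets.datasets.local import BASE_PATH

ir_datasets.create_local_dataset(
    'local-demo',  # illustrative id
    queries=[{'query_id': 'q1', 'text': 'an example query'}],
)

with (BASE_PATH / 'registry.json').open('rt') as fin:
    print(json.load(fin))
# e.g. [{'id': 'local-demo', 'path': '<uuid>', 'provides': {'queries': None}}]

# Unregister and remove the files again.
ir_datasets.delete_local_dataset('local-demo', remove_files=True)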