Dataset registry (dmlc#68)
* Improve word embeddings evaluation docs, tests and code style

* Decrease default batch-size for word embedding evaluation to avoid OOM

* Add dataset and keyword registry

* Enable dataset registry for word embedding evaluation datasets

* Update word embedding evaluation script based on dataset registry

* Scope code in script behind __main__

Fix redefined-outer-name lint error

* Enable Dataset registry for all Datasets

* Workaround Py3 inspect.getargspec deprecation

* Add logging option to word embedding evaluation script

* Add __all__
leezu authored and szha committed Apr 23, 2018
1 parent b4f142a commit 5c6e6e8
Showing 13 changed files with 535 additions and 219 deletions.
51 changes: 51 additions & 0 deletions gluonnlp/_constants.py
@@ -714,6 +714,14 @@
'wiki.zu': ('wiki.zu.npz',
'fc9ce07d5d0c49a3c86cf1b26056ada58f9404ca')}

GOOGLEANALOGY_CATEGORIES = [
'capital-common-countries', 'capital-world', 'currency', 'city-in-state',
'family', 'gram1-adjective-to-adverb', 'gram2-opposite',
'gram3-comparative', 'gram4-superlative', 'gram5-present-participle',
'gram6-nationality-adjective', 'gram7-past-tense', 'gram8-plural',
'gram9-plural-verbs'
]

BATS_CHECKSUMS = \
{'BATS_3.0/1_Inflectional_morphology/I01 [noun - plural_reg].txt':
'cfcba2835edf81abf11b84defd2f4daa3ca0b0bf',
@@ -796,6 +804,49 @@
'BATS_3.0/4_Lexicographic_semantics/L10 [antonyms - binary].txt':
'3cde2f2c2a0606777b8d7d11d099f316416a7224'}

BATS_CATEGORIES = {
'I01': '[noun - plural_reg]',
'I02': '[noun - plural_irreg]',
'I03': '[adj - comparative]',
'I04': '[adj - superlative]',
'I05': '[verb_inf - 3pSg]',
'I06': '[verb_inf - Ving]',
'I07': '[verb_inf - Ved]',
'I08': '[verb_Ving - 3pSg]',
'I09': '[verb_Ving - Ved]',
'I10': '[verb_3pSg - Ved]',
'D01': '[noun+less_reg]',
'D02': '[un+adj_reg]',
'D03': '[adj+ly_reg]',
'D04': '[over+adj_reg]',
'D05': '[adj+ness_reg]',
'D06': '[re+verb_reg]',
'D07': '[verb+able_reg]',
'D08': '[verb+er_irreg]',
'D09': '[verb+tion_irreg]',
'D10': '[verb+ment_irreg]',
'E01': '[country - capital]',
'E02': '[country - language]',
'E03': '[UK_city - county]',
'E04': '[name - nationality]',
'E05': '[name - occupation]',
'E06': '[animal - young]',
'E07': '[animal - sound]',
'E08': '[animal - shelter]',
'E09': '[things - color]',
'E10': '[male - female]',
'L01': '[hypernyms - animals]',
'L02': '[hypernyms - misc]',
'L03': '[hyponyms - misc]',
'L04': '[meronyms - substance]',
'L05': '[meronyms - member]',
'L06': '[meronyms - part]',
'L07': '[synonyms - intensity]',
'L08': '[synonyms - exact]',
'L09': '[antonyms - gradable]',
'L10': '[antonyms - binary]'
}

SEMEVAL17_CHECKSUMS = \
{'SemEval17-Task2/README.txt':
'ad02d4c22fff8a39c9e89a92ba449ec78750af6b',
5 changes: 4 additions & 1 deletion gluonnlp/data/__init__.py
@@ -22,6 +22,8 @@

from .utils import *

from .registry import *

from .transforms import *

from .sampler import *
@@ -42,4 +44,5 @@

__all__ = (utils.__all__ + transforms.__all__ + sampler.__all__ +
           dataset.__all__ + language_model.__all__ + sentiment.__all__ +
           word_embedding_evaluation.__all__ + conll.__all__ +
           translation.__all__ + registry.__all__)
11 changes: 11 additions & 0 deletions gluonnlp/data/conll.py
@@ -34,6 +34,7 @@
from mxnet.gluon.utils import download, check_sha1

from .. import _constants as C
from .registry import register


class _CoNLLSequenceTagging(SimpleDataset):
@@ -97,6 +98,7 @@ def _process_iter(self, line_iter):
return samples


@register(segment=['train', 'test'])
class CoNLL2000(_CoNLLSequenceTagging):
"""CoNLL2000 Part-of-speech (POS) tagging and chunking joint task dataset.
@@ -122,6 +124,8 @@ def __init__(self, segment='train', root=os.path.join('~', '.mxnet', 'datasets',
base_url = 'http://www.clips.uantwerpen.be/conll2000/chunking/'
codec = 'utf-8'


@register(segment=['train', 'testa', 'testb'], part=[1, 2, 3])
class CoNLL2001(_CoNLLSequenceTagging):
"""CoNLL2001 Clause Identification dataset.
@@ -176,6 +180,8 @@ def _get_data_file_hash(self):
available_segments)
return [self._data_file[self._part-1][self._segment]]


@register(segment=['train', 'testa', 'testb'], lang=['esp', 'ned'])
class CoNLL2002(_CoNLLSequenceTagging):
"""CoNLL2002 Named Entity Recognition (NER) task dataset.
@@ -226,6 +232,8 @@ def _get_data_file_hash(self):
available_segments)
return [self._data_file[self._lang][self._segment]]


@register(segment=['train', 'dev', 'test'])
class CoNLL2004(_CoNLLSequenceTagging):
"""CoNLL2004 Semantic Role Labeling (SRL) task dataset.
@@ -297,6 +305,9 @@ def _extract_archive(self):
shutil.copy(fn, root)
shutil.rmtree(os.path.join(root, 'conll04st-release'), ignore_errors=True)


@register(segment=['train', 'dev', 'test'],
lang=list(C.UD21_DATA_FILE_SHA1.keys()))
class UniversalDependencies21(_CoNLLSequenceTagging):
"""Universal dependencies tree banks.
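With the decorators in place, a CoNLL dataset can now be looked up and constructed by its registered name. A minimal usage sketch, assuming gluonnlp is installed (the keyword values are exactly those declared in the @register calls above; constructing a dataset downloads its files on first use):

import gluonnlp

# Keyword values declared via the @register decorators above.
print(gluonnlp.data.list_datasets('conll2002'))
# {'segment': ['train', 'testa', 'testb'], 'lang': ['esp', 'ned']}

# Construct by (case-insensitive) registered name.
ner_dev = gluonnlp.data.create('conll2002', segment='testa', lang='esp')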
3 changes: 3 additions & 0 deletions gluonnlp/data/language_model.py
@@ -30,6 +30,7 @@

from .. import _constants as C
from .dataset import LanguageModelDataset
from .registry import register


class _WikiText(LanguageModelDataset):
@@ -64,6 +65,7 @@ def _get_data(self):
return path


@register(segment=['train', 'val', 'test'])
class WikiText2(_WikiText):
"""WikiText-2 word-level dataset for language modeling, from Salesforce research.
@@ -98,6 +100,7 @@ def __init__(self, segment='train', skip_empty=True, bos=None, eos=C.EOS_TOKEN,
super(WikiText2, self).__init__('wikitext-2', segment, bos, eos, skip_empty, root)


@register(segment=['train', 'val', 'test'])
class WikiText103(_WikiText):
"""WikiText-103 word-level dataset for language modeling, from Salesforce research.
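The language modeling corpora follow the same pattern. A one-line sketch, assuming gluonnlp is installed (downloads the corpus on first use):

import gluonnlp

wikitext_val = gluonnlp.data.create('wikitext2', segment='val')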
146 changes: 146 additions & 0 deletions gluonnlp/data/registry.py
@@ -0,0 +1,146 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A registry for datasets
The registry makes it simple to construct a dataset given its name.
"""
__all__ = ['register', 'create', 'list_datasets']

import inspect

from mxnet import registry
from mxnet.gluon.data import Dataset

_REGISTRY_NAME_KWARGS = {}


def register(class_=None, **kwargs):
"""Registers a dataset with segment specific hyperparameters.
When passing keyword arguments to `register`, they are checked to be valid
keyword arguments for the registered Dataset class constructor and are
saved in the registry. Registered keyword arguments can be retrieved with
the `list_datasets` function.
All arguments that result in creation of separate datasets should be
registered. Examples are datasets divided in different segments or
categories, or datasets containing multiple languages.
Once registered, an instance can be created by calling
:func:`~gluonnlp.data.create` with the class name.
Parameters
----------
**kwargs : list or tuple of allowed argument values
For each keyword argument, it's value must be a list or tuple of the
allowed argument values.
Examples
--------
>>> @gluonnlp.data.register(segment=['train', 'test', 'dev'])
... class MyDataset(Dataset):
... def __init__(self, segment='train'):
... pass
>>> my_dataset = gluonnlp.embedding.create('MyDataset')
>>> print(type(my_dataset))
<class '__main__.MyDataset'>
"""

    def _real_register(class_):
        # Assert that the passed kwargs are meaningful
        try:
            # Py3: getfullargspec also covers keyword-only arguments
            real_args = inspect.getfullargspec(class_).args
        except AttributeError:
            # Py2 fallback: getargspec is deprecated as of Py3
            # pylint: disable=deprecated-method
            real_args = inspect.getargspec(class_.__init__).args

        for kwarg_name, values in kwargs.items():
            if kwarg_name not in real_args:
                raise RuntimeError(
                    ('{} is not a valid argument for {}. '
                     'Only valid arguments can be registered.').format(
                         kwarg_name, class_.__name__))

            if not isinstance(values, (list, tuple)):
                raise RuntimeError(
                    ('{} should be a list or tuple of valid argument '
                     'values for {}.').format(values, kwarg_name))

# Save the kwargs associated with this class_
        _REGISTRY_NAME_KWARGS[class_] = kwargs

register_ = registry.get_register_func(Dataset, 'dataset')
return register_(class_)

if class_ is not None:
# Decorator was called without arguments
return _real_register(class_)

return _real_register


def create(name, **kwargs):
"""Creates an instance of a registered dataset.
Parameters
----------
name : str
The dataset name (case-insensitive).
Returns
-------
An instance of :class:`mxnet.gluon.data.Dataset` constructed with the
keyword arguments passed to the create function.
"""
create_ = registry.get_create_func(Dataset, 'dataset')
return create_(name, **kwargs)


def list_datasets(name=None):
"""Get valid datasets and registered parameters.
Parameters
----------
name : str or None, default None
Return names and registered parameters of registered datasets. If name
is specified, only registered parameters of the respective dataset are
returned.
Returns
-------
dict:
A dict of all the valid keyword parameters names for the specified
dataset. If name is set to None, returns a dict mapping each valid name
to its respective keyword parameter dict. The valid names can be
plugged in `gluonnlp.model.word_evaluation_model.create(name)`.
"""
    reg = registry.get_registry(Dataset)

    if name is not None:
        class_ = reg[name.lower()]
        return _REGISTRY_NAME_KWARGS[class_]
    else:
        return {
            dataset_name: _REGISTRY_NAME_KWARGS[class_]
            for dataset_name, class_ in reg.items()
        }
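Taken together, registry.py exposes a three-function API: register, create, and list_datasets. A minimal end-to-end sketch, assuming gluonnlp and mxnet are installed; MyDataset and its segments are illustrative stand-ins, not part of this commit:

from mxnet.gluon.data import SimpleDataset

import gluonnlp


@gluonnlp.data.register(segment=['train', 'test'])
class MyDataset(SimpleDataset):
    """A toy dataset; only the values listed in @register are advertised."""

    def __init__(self, segment='train'):
        super(MyDataset, self).__init__(['dummy sample for ' + segment])


# Registered names are case-insensitive.
dataset = gluonnlp.data.create('mydataset', segment='test')
print(type(dataset))                             # <class '__main__.MyDataset'>
print(gluonnlp.data.list_datasets('mydataset'))  # {'segment': ['train', 'test']}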
2 changes: 2 additions & 0 deletions gluonnlp/data/sentiment.py
@@ -27,8 +27,10 @@

from mxnet.gluon.data import SimpleDataset
from mxnet.gluon.utils import download, check_sha1, _get_repo_file_url
from .registry import register


@register(segment=['train', 'test', 'unsup'])
class IMDB(SimpleDataset):
"""IMDB reviews for sentiment analysis.
4 changes: 4 additions & 0 deletions gluonnlp/data/translation.py
@@ -33,6 +33,7 @@

from .dataset import TextLineDataset
from ..vocab import Vocab
from .registry import register


def _get_pair_key(src_lang, tgt_lang):
@@ -148,6 +149,7 @@ def tgt_vocab(self):
return self._tgt_vocab


@register(segment=['train', 'val', 'test'])
class IWSLT2015(_TranslationDataset):
"""Preprocessed IWSLT English-Vietnamese Translation Dataset.
@@ -190,6 +192,8 @@ def __init__(self, segment='train', src_lang='en', tgt_lang='vi',
tgt_lang=tgt_lang, root=root)


@register(segment=['train', 'newstest2012', 'newstest2013', 'newstest2014',
                   'newstest2015', 'newstest2016'])
class WMT2016BPE(_TranslationDataset):
"""Preprocessed Translation Corpus of the WMT2016 Evaluation Campaign.
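Because every public Dataset class now carries the decorator, the registry can enumerate all of them in one place. A short sketch, assuming gluonnlp is installed (the segment name follows the @register call above; the corpus downloads on first use):

import gluonnlp

# Map each registered dataset name to its registered keyword values.
for name, valid_kwargs in sorted(gluonnlp.data.list_datasets().items()):
    print(name, valid_kwargs)

# Construct a WMT test segment by its registered name.
wmt_test = gluonnlp.data.create('wmt2016bpe', segment='newstest2016')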