Skip to content

Commit

Permalink
add comments and sanity check (apache#8901)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Jan 22, 2018
1 parent 45f9609 commit 3412c44
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 55 deletions.
56 changes: 53 additions & 3 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,53 @@ def __len__(self):
raise NotImplementedError

def transform(self, fn, lazy=True):
"""Returns a new dataset with each sample transformed by the
transformer function `fn`.
Parameters
----------
fn : callable
A transformer function that takes a sample as input and
returns the transformed sample.
lazy : bool, default True
If False, transforms all samples at once. Otherwise,
transforms each sample on demand. Note that if `fn`
is stochastic, you must set lazy to True or you will
get the same result on all epochs.
Returns
-------
Dataset
The transformed dataset.
"""
trans = _LazyTransformDataset(self, fn)
if lazy:
return trans
return SimpleDataset([i for i in trans])

def transform_first(self, fn, lazy=True):
"""Returns a new dataset with the first element of each sample
transformed by the transformer function `fn`.
This is useful, for example, when you only want to transform data
while keeping label as is.
Parameters
----------
fn : callable
A transformer function that takes the first elemtn of a sample
as input and returns the transformed element.
lazy : bool, default True
If False, transforms all samples at once. Otherwise,
transforms each sample on demand. Note that if `fn`
is stochastic, you must set lazy to True or you will
get the same result on all epochs.
Returns
-------
Dataset
The transformed dataset.
"""
def base_fn(x, *args):
if args:
return (fn(x),) + args
Expand All @@ -55,6 +96,13 @@ def base_fn(x, *args):


class SimpleDataset(Dataset):
"""Simple Dataset wrapper for lists and arrays.
Parameters
----------
data : dataset-like object
Any object that implements `len()` and `[]`.
"""
def __init__(self, data):
self._data = data

Expand All @@ -66,6 +114,7 @@ def __getitem__(self, idx):


class _LazyTransformDataset(Dataset):
"""Lazily transformed dataset."""
def __init__(self, data, fn):
self._data = data
self._fn = fn
Expand All @@ -81,13 +130,14 @@ def __getitem__(self, idx):


class ArrayDataset(Dataset):
"""A dataset of multiple arrays.
"""A dataset that combines multiple dataset-like objects, e.g.
Datasets, lists, arrays, etc.
The i-th sample is `(x1[i], x2[i], ...)`.
The i-th sample is defined as `(x1[i], x2[i], ...)`.
Parameters
----------
*args : one or more arrays
*args : one or more dataset-like objects
The data arrays.
"""
def __init__(self, *args):
Expand Down
Loading

0 comments on commit 3412c44

Please sign in to comment.