Description
I'd like to open a discussion on how we should move forward with implementing a getobs and nobs compliant API, while possibly also simplifying the interface and reducing the maintenance burden.
I think we should move away from the module-based approach and adopt a type-based one. It could also be convenient to have a lean type hierarchy.
Below is an initial proposal.
AbstractDatasets
####### src/datasets.jl
abstract type AbstractDataset end
abstract type FileDataset <: AbstractDataset end
abstract type InMemoryDataset <: AbstractDataset end
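To illustrate what a lean hierarchy could buy us, here is a minimal, self-contained sketch in which generic fallbacks are written once against AbstractDataset. The getobs and nobs functions below are local stand-ins for the LearnBase ones, and ToyDataset is a hypothetical type used only for illustration; none of this is meant as the final design.

```julia
# Local stand-ins for LearnBase.getobs / LearnBase.nobs so the sketch runs on its own.
function getobs end
function nobs end

abstract type AbstractDataset end
abstract type FileDataset <: AbstractDataset end
abstract type InMemoryDataset <: AbstractDataset end

# Fallbacks defined once for every dataset in the hierarchy.
Base.length(d::AbstractDataset) = nobs(d)
Base.show(io::IO, d::AbstractDataset) =
    print(io, nameof(typeof(d)), " with ", nobs(d), " observations")

# Hypothetical toy dataset (assumption, for illustration only).
struct ToyDataset <: InMemoryDataset
    x::Matrix{Float32}
    targets::Vector{Int}
end

getobs(d::ToyDataset) = (d.x, d.targets)
getobs(d::ToyDataset, idx) = (d.x[:, idx], d.targets[idx])
nobs(d::ToyDataset) = length(d.targets)

toy = ToyDataset(rand(Float32, 4, 6), collect(1:6))
println(toy)
```

A concrete dataset then only has to implement getobs and nobs; everything else comes from the abstract type.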
MNIST Dataset
###### src/vision/mnist.jl
"""
docstring here, also exposing the internal fields of the struct for transparency
"""
struct MNIST <: InMemoryDataset
    x # alternative names: `features` or `inputs`
    targets # alternative names: `labels` or `y`
    num_classes # optional

    function MNIST(path=nothing; split = :train) # split could be made a mandatory keyword arg
        @assert split in [:train, :test]
        ..........
    end
end
LearnBase.getobs(data::MNIST) = (data.x, data.targets)
LearnBase.getobs(data::MNIST, idx) = (data.x[:,idx], data.targets[idx])
LearnBase.nobs(data::MNIST) = length(data.targets)
.... other stuff ....
Usage
using MLDatasets: MNIST
using Flux.
train_data = MNIST(split = :train)
test_data = MNIST(split = :test)
xtrain, ytrain = getobs(train_data)
xtrain, ytrain = train_data # we can add this for convenience
xs, ys = getobs(train_data, 1:10)
xs, ys = train_data[1:10] # we can add this for convenience
train_loader = DataLoader(train_data; batch_size=128)
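The two convenience forms above could be supported roughly as follows. This is only a sketch under assumptions: getobs and nobs are local stand-ins for the LearnBase functions, and ToyDataset is a hypothetical type with fake data. Indexing forwards to getobs, and destructuring goes through the iteration protocol.

```julia
# Local stand-ins for LearnBase.getobs / LearnBase.nobs (assumption).
function getobs end
function nobs end

# Hypothetical toy dataset with the same field layout as the proposed MNIST struct.
struct ToyDataset
    x::Matrix{Float32}
    targets::Vector{Int}
end

getobs(d::ToyDataset) = (d.x, d.targets)
getobs(d::ToyDataset, idx) = (d.x[:, idx], d.targets[idx])
nobs(d::ToyDataset) = length(d.targets)

# `data[idx]` forwards to `getobs(data, idx)`.
Base.getindex(d::ToyDataset, idx) = getobs(d, idx)

# `x, y = data` destructures by iterating over the tuple returned by `getobs(data)`.
Base.iterate(d::ToyDataset) = iterate(getobs(d))
Base.iterate(d::ToyDataset, state) = iterate(getobs(d), state)

train_data = ToyDataset(rand(Float32, 4, 20), rand(0:9, 20))
xtrain, ytrain = train_data
xs, ys = train_data[1:10]
```

One trade-off to keep in mind: with this definition, iterating over a dataset yields (x, targets) rather than individual observations, so `for obs in train_data` would not loop over samples.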
Transforms
Do we need transformations as part of the datasets?
This is a possible interface, assuming the transform operates on whatever getobs returns:
getobs(data::MNIST, idx) = data.transform(data.x[:,idx], data.targets[idx])
MNIST(split = :train, transform = (x, y) -> (random_crop(x), y))
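To make the transform idea concrete, here is a small runnable sketch. It assumes the dataset stores the transform as an extra field with a no-op default; ToyDataset, identity_transform and the doubling transform are hypothetical names chosen for illustration, and getobs is a local stand-in for the LearnBase function.

```julia
# Local stand-in for LearnBase.getobs (assumption).
function getobs end

# Hypothetical dataset carrying a transform that is applied to (x, y) pairs.
struct ToyDataset{F}
    x::Matrix{Float32}
    targets::Vector{Int}
    transform::F
end

# Default transform: return the observation unchanged.
identity_transform(x, y) = (x, y)

# Convenience constructor with the transform as a keyword argument.
ToyDataset(x, targets; transform = identity_transform) =
    ToyDataset(x, targets, transform)

# The transform sees whatever the untransformed getobs would have returned.
getobs(d::ToyDataset, idx) = d.transform(d.x[:, idx], d.targets[idx])

x = rand(Float32, 4, 8)
t = collect(1:8)

plain = ToyDataset(x, t)
scaled = ToyDataset(x, t; transform = (x, y) -> (2 .* x, y))

xs_plain, ys_plain = getobs(plain, 1:3)
xs_scaled, ys_scaled = getobs(scaled, 1:3)
```

Keeping the transform out of the stored arrays means the raw data stays untouched and the same files can back several differently transformed datasets.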
Deprecation Path 1
We can create a deprecation path for the code
using MLDatasets: MNIST
xtrain, ytrain = MNIST.traindata(...)
by implementing
# MNIST.traindata accesses a property of the type itself, so we dispatch on Type{MNIST}
function Base.getproperty(::Type{MNIST}, s::Symbol)
    if s == :traindata
        @warn "deprecated method"
        return ....
    ....
    end
end
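A self-contained version of this idea might look like the sketch below. ToyMNIST and its fake data are hypothetical stand-ins for the real dataset, and the warning text is only a suggestion. Since ToyMNIST.traindata is a property access on the type, the method dispatches on Type{ToyMNIST} and falls back to getfield for everything else.

```julia
# Hypothetical stand-in for the proposed MNIST struct, filled with fake data.
struct ToyMNIST
    x
    targets
end

function ToyMNIST(; split = :train)
    @assert split in (:train, :test)
    n = split == :train ? 6 : 2
    return ToyMNIST(zeros(Float32, 4, n), collect(1:n))
end

# Intercept `ToyMNIST.traindata` so that old code keeps working with a warning.
function Base.getproperty(::Type{ToyMNIST}, s::Symbol)
    if s === :traindata
        @warn "ToyMNIST.traindata() is deprecated, use ToyMNIST(split = :train) instead." maxlog = 1
        # Return a function so that the old call syntax `ToyMNIST.traindata()` still works.
        return () -> begin
            d = ToyMNIST(split = :train)
            (d.x, d.targets)
        end
    end
    # Everything else behaves as for any other type.
    return getfield(ToyMNIST, s)
end

xtrain, ytrain = ToyMNIST.traindata()
```

The getfield fallback matters because Julia itself reads properties of types, so anything other than the deprecated names has to keep its normal behavior.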
Deprecation Path 2
The pattern
using MLDatasets.MNIST: traindata
xtrain, ytrain = traindata(...)
is instead more problematic, because it assumes a module MNIST exists, and this (deprecated) module would collide with the struct MNIST. A workaround is to call the new struct MNISTDataset, although I'm not super happy with this long name.
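For reference, the MNISTDataset workaround could be laid out roughly like this. It is a sketch under assumptions: ToyDatasets is a hypothetical stand-in for the package module, the data is fake, and the warning text is only a suggestion. The struct takes the new name, while a small deprecated MNIST module keeps the old import path alive.

```julia
module ToyDatasets  # hypothetical stand-in for the package module

export MNISTDataset

# New type-based API under the longer name, filled with fake data.
struct MNISTDataset
    x
    targets
end

function MNISTDataset(; split = :train)
    @assert split in (:train, :test)
    n = split == :train ? 6 : 2
    return MNISTDataset(zeros(Float32, 4, n), collect(1:n))
end

# Deprecated module, kept only so that `using ToyDatasets.MNIST: traindata` still works.
module MNIST
    import ..MNISTDataset
    export traindata

    function traindata()
        @warn "MNIST.traindata() is deprecated, use MNISTDataset(split = :train) instead." maxlog = 1
        d = MNISTDataset(split = :train)
        return d.x, d.targets
    end
end

end # module ToyDatasets

# Old-style usage keeps working, with a deprecation warning.
using .ToyDatasets.MNIST: traindata
xtrain, ytrain = traindata()
```

Once the deprecated module is removed in a later breaking release, the name MNIST would be free again if we wanted to rename the struct back.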