Skip to content

Improve indexing to support table behavior #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
238 changes: 143 additions & 95 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,13 @@ abstract type AbstractModel end

Base.parent(m::AbstractModel) = getfield(m, :parent)
setparent(m::AbstractModel, newparent) = @set m.parent = newparent
function setparent!(m::AbstractModel, newparent)
setfield!(m, :parent, newparent)
return m
end

params(m::AbstractModel) = params(parent(m))
stripparams(m::AbstractModel) = stripparams(parent(m))
function update(x::T, values) where {T<:AbstractModel}
hasfield(T, :parent) || _updatenotdefined(T)
setparent(x, update(parent(x), values))
end

@noinline _update_methoderror(T) = error("Interface method `update` is not defined for $T")

paramfieldnames(m) = Flatten.fieldnameflatten(parent(m), SELECT, IGNORE)
paramparenttypes(m) = Flatten.metaflatten(parent(m), _fieldparentbasetype, SELECT, IGNORE)
_fieldparentbasetype(T, ::Type{Val{N}}) where N = component(T)

"""
component(::Type{T}) where T
Expand All @@ -104,100 +98,179 @@ implementation simply uses `T.name.wrapper` which is the `UnionAll` type corresp
the unparameterized type name of `T`.
"""
component(::Type{T}) where T = T.name.wrapper
component(T, ::Type{Val{N}}) where N = component(T)
paramfieldnames(m::AbstractModel) = Flatten.fieldnameflatten(parent(m), SELECT, IGNORE)
paramcomponents(m::AbstractModel) = Flatten.metaflatten(parent(m), component, SELECT, IGNORE)

function Base.show(io::IO, mime::MIME"text/plain", m::AbstractModel)
show(io, mime, typeof(m))
println(io, " with parent object of type: \n")
show(io, mime, typeof(parent(m)))
println(io, "\n\n")
printparams(io::IO, m)
end
printparams(m) = printparams(stdout, m)
function printparams(io::IO, m::AbstractModel)
if length(m) > 0
println(io, "Parameters:")
PrettyTables.pretty_table(io, m; header=[keys(m)...])
end
end

# Tuple-like indexing and iterables interface

# It may seem expensive always calling `param`, but flattening the
# object occurs once at compile-time, and should have very little cost here.
Base.IndexStyle(::Type{<:AbstractModel}) = IndexCartesian()
Base.length(m::AbstractModel) = length(params(m))
Base.size(m::AbstractModel) = (length(params(m)),)
Base.size(m::AbstractModel) = (length(params(m)), length(keys(m)))
Base.first(m::AbstractModel) = first(params(m))
Base.last(m::AbstractModel) = last(params(m))
Base.firstindex(m::AbstractModel) = 1
Base.lastindex(m::AbstractModel) = length(params(m))
Base.getindex(m::AbstractModel, i) = getindex(params(m), i)
Base.iterate(m::AbstractModel) = (first(params(m)), 1)
Base.iterate(m::AbstractModel) = (first(params(m)), 2)
Base.iterate(m::AbstractModel, s) = s > length(m) ? nothing : (params(m)[s], s + 1)
Base.eachcol(m::AbstractModel) = (m[col] for col in keys(m))
Base.eachrow(m::AbstractModel) = m

# Vector methods
Base.collect(m::AbstractModel) = collect(m.val)
Base.collect(m::AbstractModel) = collect(m[:val])
Base.vec(m::AbstractModel) = collect(m)
Base.Array(m::AbstractModel) = vec(m)

# Dict methods - data as columns
Base.haskey(m::AbstractModel, key::Symbol) = key in keys(m)
Base.keys(m::AbstractModel) = _keys(params(m), m)

@inline function Base.setindex!(m::AbstractModel, x, nm::Symbol)
if nm == :component
erorr("cannot set :component index")
elseif nm == :fieldname
erorr("cannot set :fieldname index")
else
newparent = if nm in keys(m)
_setindex(parent(m), Tuple(x), nm)
else
_addindex(parent(m), Tuple(x), nm)
end
setparent!(m, newparent)
_keys(params::Tuple, ::AbstractModel) = (:component, :fieldname, keys(first(params))...)
_keys(::Tuple{}, ::AbstractModel) = ()
_isreserved(key::Symbol) = key == :component || key == :fieldname
@inline _enumerate(tup::NTuple{N,Any}) where N = map(tuple, tuple(1:N...), tup) # type stable tuple enumeration

# Indexing kernels
const RowIndexer = Union{Integer,Colon,AbstractVector{<:Integer}}
@inline _getindex(ps::Tuple{Vararg{<:Param}}, i::RowIndexer) = _getindex(ps, i, :)
@inline _getindex(ps::Tuple{Vararg{<:Param}}, col::Symbol) = _getindex(ps, :, col)
@inline _getindex(ps::Tuple{Vararg{<:Param}}, i::RowIndexer, ::Colon) = ps[i]
@inline _getindex(ps::Tuple{Vararg{<:Param}}, i::RowIndexer, col::Symbol) = map(p -> p[col], ps[i])
@inline _getindex(ps::Tuple{Vararg{<:Param}}, i::AbstractVector{Bool}, col::Symbol) = _getindex(ps, findall(i), col)
@inline _setindex(obj, xs, ::Colon, cols) = _setindex(obj, xs, 1:length(params(obj)), cols)
@inline _setindex(obj, xs, idxs::AbstractVector{Bool}, cols::AbstractVector) = _setindex(obj, xs, findall(idxs), cols)
@inline _setindex(obj, xs, idxs::AbstractVector{Bool}, ::Type{Val{col}}) where col = _setindex(obj, xs, findall(idxs), Val{col})
@inline @generated _setindex(ps::Tuple{Vararg{<:Param}}, x, i::Integer, ::Type{Val{col}}) where col = :(@set ps[i].$col = x)
@inline function _setindex(obj, x, i::Integer, ::Type{Val{col}}) where col
ps = params(obj)
newps = _setindex(ps, x, i, Val{col})
return Flatten.reconstruct(obj, newps, SELECT, IGNORE)
end
@inline function _setindex(obj, xs, ::Colon, ::Type{Val{col}}) where col
# handle special case for ::Colon (all indices) where we can be type stable
ps = params(obj)
newps = map(_enumerate(ps)) do (i,p)
_setindex((p,), xs[i], 1, Val{col})[1]
end
return Flatten.reconstruct(obj, newps, SELECT, IGNORE)
end
# TODO do this with lenses
@inline function _setindex(obj, xs::Tuple, nm::Symbol)
lens = Setfield.PropertyLens{nm}()
newparams = map(params(obj), xs) do par, x
Param(Setfield.set(parent(par), lens, x))
@inline function _setindex(obj, xs, idxs::AbstractVector{<:Integer}, ::Type{Val{col}}) where col
ps = params(obj)
for i in 1:length(idxs)
ps = _setindex(ps, xs[i], idxs[i], Val{col})
end
Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
return Flatten.reconstruct(obj, ps, SELECT, IGNORE)
end
@inline function _addindex(obj, xs::Tuple, nm::Symbol)
newparams = map(params(obj), xs) do par, x
Param((; parent(par)..., (nm => x,)...))
@inline function _addindex(obj, xs, ::Type{Val{col}}) where col
newparams = map(params(obj), xs) do p, x
Param((; parent(p)..., col => x))
end
Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
return Flatten.reconstruct(obj, newparams, SELECT, IGNORE)
end

_keys(params::Tuple, m::AbstractModel) = (:component, :fieldname, keys(first(params))...)
_keys(params::Tuple{}, m::AbstractModel) = ()

@inline function Base.getindex(m::AbstractModel, nm::Symbol)
if nm == :component
paramparenttypes(m)
elseif nm == :fieldname
paramfieldnames(m)
# Indexing interface
@inline Base.getindex(m::AbstractModel, col::Symbol) = getindex(m, :, col)
@inline Base.getindex(m::AbstractModel, i::RowIndexer) = getindex(m, i, :)
@inline Base.getindex(m::AbstractModel, ::Colon, ::Colon) = m
@inline function Base.getindex(m::AbstractModel, i::RowIndexer, col)
return if col == :component
paramcomponents(m)[i]
elseif col == :fieldname
paramfieldnames(m)[i]
else
map(p -> getindex(p, nm), params(m))
_getindex(params(m), i, col)
end
end

function Base.show(io::IO, mime::MIME"text/plain", m::AbstractModel)
show(io, mime, typeof(m))
println(io, " with parent object of type: \n")
show(io, mime, typeof(parent(m)))
println(io, "\n\n")
printparams(io::IO, m)
@inline Base.setindex(m::AbstractModel, xs, col::Union{Symbol,Type{<:Val}}) = Base.setindex(m, xs, :, col)
@inline Base.setindex(m::AbstractModel, xs, i::RowIndexer) = Base.setindex(m, xs, i, :)
@inline Base.setindex(m::AbstractModel, xs, i::RowIndexer, col::Symbol) = Base.setindex(m, xs, i, Val{col})
@inline Base.setindex(m::AbstractModel, xs, i::RowIndexer, ::Type{Val{col}}) where col = Base.setindex(m, xs, collect(i), Val{col})
@inline Base.setindex(m::AbstractModel, xs, i::RowIndexer, ::Colon) = Base.setindex(m, xs, collect(i), filter(!_isreserved, keys(m)))
@inline Base.setindex(m::AbstractModel, xs, i::Integer, ::Colon) = Base.setindex(m, xs, i, filter(!_isreserved, keys(m)))
@inline Base.setindex(m::AbstractModel, xs, i::Integer, ::Type{Val{col}}) where col = _setindex(m, xs, i, Val{col})
@inline function Base.setindex(m::AbstractModel, xs, i::RowIndexer, cols::Union{Tuple,AbstractVector})
return foldl(cols; init=m) do m, col
Base.setindex(m, Tables.getcolumn(xs, col), i, col)
end
end

printparams(m) = printparams(stdout, m)
function printparams(io::IO, m::AbstractModel)
if length(m) > 0
println(io, "Parameters:")
PrettyTables.pretty_table(io, m, [keys(m)...])
@inline function Base.setindex(m::AbstractModel, xs, i::AbstractVector{<:Integer}, ::Type{Val{col}}) where col
@assert !_isreserved(col) "column name :$col is reserved and cannot be modified"
@assert col ∈ keys(m) "column $col does not exist"
return _setindex(m, xs, i, Val{col})
end
@inline function Base.setindex(m::AbstractModel, xs, ::Colon, ::Type{Val{col}}) where col
@assert !_isreserved(col) "column name :$col is reserved and cannot be modified"
return if col ∈ keys(m)
_setindex(m, xs, :, Val{col})
else
_addindex(m, xs, Val{col})
end
end
@inline Base.setindex!(m::AbstractModel, xs, col::Union{Symbol,Type{Val}}) = setindex!(m, xs, :, col)
@inline Base.setindex!(m::AbstractModel, xs, i::RowIndexer) = setindex!(m, xs, i, :)
@inline Base.setindex!(m::AbstractModel, xs, i::RowIndexer, col) = setparent!(m, parent(Base.setindex(m, xs, i, col)))

setparent!(m::AbstractModel, newparent) = setfield!(m, :parent, newparent)
# helper function for evaluating indexing predicates
@inline _indices(m, rule) = findall(map(rule, map((c, n, p) -> setparent(p, (; parent(p)..., :component => c, :fieldname => n)), paramcomponents(m), paramfieldnames(m), params(m))))
@inline _indices(::Any, ::Nothing) = Colon()

update!(m::AbstractModel, vals::AbstractVector{<:AbstractParam}) = update!(m, Tuple(vals))
function update!(params::Tuple{<:AbstractParam,Vararg{<:AbstractParam}})
setparent!(m, Flatten.reconstruct(parent(m), params, SELECT, IGNORE))
end
function update!(m::AbstractModel, table)
cols = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname)))
for col in cols
setindex!(m, Tables.getcolumn(table, col), col)
end
m
# Update (value)
"""
update(obj, xs::Union{AbstractVector,Tuple}, rule=nothing)

Updates the `val` field of `Param`s at rows selected by `rule` in `obj` with the values in `xs`. Type stable and
allocation free when all rows are selected (i.e. `rule=nothing`).
"""
@inline update(obj, xs::Union{AbstractVector,Tuple}, rule=nothing) = _setindex(obj, xs, _indices(obj, rule), Val{:val})
@inline update(m::AbstractModel, xs::Union{AbstractVector,Tuple}, rule=nothing) = Base.setindex(m, xs, _indices(m, rule), Val{:val})
# Update (table)
"""
update(m::AbstractModel, table, rule=nothing)

Updates the columns of `Param`s at rows selected by `rule` in `m` with the values in `table`, which must implement
the `Tables.jl` interface.
"""
@inline update(m::AbstractModel, table, rule=nothing) = Base.setindex(m, table, _indices(m, rule), filter(!_isreserved, Tables.columnnames(table)))
# Update helpers
"""
update!(m::AbstractModel, xs, rule=nothing)

Mutating version of `update` which sets the parent of `m` to the updated value.
"""
@inline update!(m::AbstractModel, xs, rule=nothing) = setparent!(m, parent(update(m, xs, rule)))
"""
update(f, m::AbstractModel, rule=nothing)
update!(f, m::AbstractModel, rule=nothing)

Updates `Param`s in `m` matching the predicate `rule` with values produced by function `f`.
`rule` should be a function of the form `rule(::Param)::Bool` and `f` should have the
form `f(::Param)::T` where `T` is a vector of values or table.
"""
update!(f, m::AbstractModel, rule=nothing) = setparent!(m, parent(update(f, m, rule)))
function update(f, m::AbstractModel, rule=nothing)
# case 1: use :val field for 1-D vector
astable(xs::AbstractVector{<:Number}) = Tables.table(xs; header=[:val])
astable(xs) = Tables.columns(xs)
ps = params(m)
idxs = _indices(m, rule)
xs = astable(collect(map(f, ps[idxs])))
return Base.setindex(m, xs, idxs, Tables.columnnames(xs))
end

"""
Expand Down Expand Up @@ -227,30 +300,6 @@ mutable struct Model <: AbstractModel
end
Model(m::AbstractModel) = Model(parent(m))

@inline @generated function _update_params(ps::P, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}}
expr = Expr(:tuple)
for i in 1:N
expr_i = :(Param(NamedTuple{keys(ps[$i])}((values[$i], Base.tail(parent(ps[$i]))...))))
push!(expr.args, expr_i)
end
return expr
end

update(x, values) = _update(ModelParameters.params(x), x, values)
@inline function _update(p::P, x, values::Union{<:AbstractVector,<:Tuple}) where {N,P<:NTuple{N,Param}}
@assert length(values) == N "values length must match the number of parameters"
newparams = _update_params(p, values)
Flatten.reconstruct(x, newparams, SELECT, IGNORE)
end
@inline function _update(p::P, x, table) where {N,P<:NTuple{N,Param}}
@assert size(table, 1) == N "number of rows must match the number of parameters"
cols = (c for c in Tables.columnnames(table) if !(c in (:component, :fieldname)))
newparams = map(p, tuple(1:N...)) do param, i
Param(NamedTuple{keys(param)}(map(name -> Tables.getcolumn(table, name)[i], cols)))
end
Flatten.reconstruct(x, newparams, SELECT, IGNORE)
end

"""
StaticModel(x)

Expand All @@ -275,7 +324,6 @@ StaticModel(m::AbstractModel) = StaticModel(parent(m))

# Model Utils

_expandpars(x) = Flatten.reconstruct(parent, _expandkeys(parent), SELECT, IGNORE)
# Expand all Params to have the same keys, filling with `nothing`
# This probably will allocate due to `union` returning `Vector`
function _expandkeys(x)
Expand Down
5 changes: 3 additions & 2 deletions src/param.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Base.values(p::AbstractParam) = values(parent(p))
@inline Base.getproperty(p::AbstractParam, x::Symbol) = getproperty(parent(p), x)
@inline Base.get(p::AbstractParam, key::Symbol, default) = get(parent(p), key, default)
@inline Base.getindex(p::AbstractParam, i) = getindex(parent(p), i)
@inline Base.getindex(p::AbstractParam, i::Integer) = getindex(parent(p), i)


# AbstractNumber interface
Expand Down Expand Up @@ -79,13 +80,13 @@ end
Param(val; kwargs...) = Param((; val=val, kwargs...))
Param(; kwargs...) = Param((; kwargs...))

setparent(::P, newparent) where P<:AbstractParam = ConstructionBase.constructorof(P)(newparent)

Base.parent(p::Param) = getfield(p, :parent)

# Methods for objects that hold params
params(x) = Flatten.flatten(x, SELECT, IGNORE)
stripparams(x) = hasparam(x) ? Flatten.reconstruct(x, withunits(x), SELECT, IGNORE) : x


# Utils
hasparam(obj) = length(params(obj)) > 0

Expand Down
3 changes: 1 addition & 2 deletions src/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ Tables.columnaccess(::Type{<:AbstractModel}) = true
Tables.columns(m::AbstractModel) = m
Tables.getcolumn(m::AbstractModel, nm::Symbol) = collect(getindex(m, nm))
Tables.getcolumn(m::AbstractModel, i::Int) = collect(getindex(m, i))
Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol) where T =
collect(getindex(m, nm))
Tables.getcolumn(m::AbstractModel, ::Type{T}, col::Int, nm::Symbol) where T = collect(getindex(m, nm))
Loading