Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
9798f58
eval_tree_array (generic) reshapes output
grezde Jun 24, 2024
656d232
added type interface so non-generic eval works with any type
grezde Jun 24, 2024
4b6d5e8
added oher types to OperatorEnum
grezde Jun 25, 2024
c7ad68c
fixed most tests
grezde Jun 25, 2024
a5efcd4
added example as test
grezde Jun 25, 2024
5468282
removed unnecessary lines in this branch
grezde Jun 25, 2024
8587599
reset Project.toml
grezde Jun 25, 2024
da1886f
Update src/Evaluate.jl
gca30 Jun 26, 2024
8d2797e
Apply suggestions from code review
gca30 Jun 26, 2024
a3d14e9
applied formatting, all other suggestions from pull request
gca30 Jun 26, 2024
225ba4d
generalized Optim, removed outdated tests
gca30 Jul 3, 2024
a8003ce
added bradcast operator support to expand_operators and showing expre…
gca30 Jul 4, 2024
08b20f6
Merge branch 'master' into pr/gca30/85
MilesCranmer Jul 4, 2024
ae404c3
remove old parsing test
MilesCranmer Jul 4, 2024
c788b38
append_number_constants no longer uses push
gca30 Jul 5, 2024
41c6d37
reformat
gca30 Jul 5, 2024
60dc72e
applied merge
gca30 Jul 5, 2024
7fbe0e6
Apply suggestions from code review
gca30 Jul 5, 2024
85d8b25
reformat
gca30 Jul 5, 2024
bdf4ad0
type interface suggestions
gca30 Jul 5, 2024
a638f84
reformat
gca30 Jul 5, 2024
1213786
rename scalar utility functions
MilesCranmer Jul 5, 2024
f4ab375
specialize type in `get_scalar_constants`
MilesCranmer Jul 5, 2024
150552d
fix tests
MilesCranmer Jul 5, 2024
16fa343
simplify use of interface
MilesCranmer Jul 5, 2024
81fd233
rename `index_constants` to `index_constant_nodes`
MilesCranmer Jul 5, 2024
7b62bd8
ci: compat with deprecated function names
MilesCranmer Jul 6, 2024
1c0d935
deps: remove unused TestItems.jl in main package
MilesCranmer Jul 6, 2024
f2217bd
style: formatting
MilesCranmer Jul 6, 2024
0284229
style: move TypeInterface imports to top of import list
MilesCranmer Jul 6, 2024
2189c7f
fix: update precompile function names
MilesCranmer Jul 6, 2024
70c46fc
style: import all extensions at once for parallel precompilation
MilesCranmer Jul 6, 2024
f8e5acf
style: update name to `set_scalar_constants!` in benchmarks
MilesCranmer Jul 6, 2024
0ad3c93
ci: add extra benchmark for parametric nodes
MilesCranmer Jul 6, 2024
308c5d1
style: clean up ParametricExpression internal methods
MilesCranmer Jul 6, 2024
0b8bd43
feat: add `ValueInterface` to formalize interface
MilesCranmer Jul 6, 2024
6fae896
refactor: rename to `ValueInterface`
MilesCranmer Jul 6, 2024
9347092
refactor: simplify `Max2Tensor`
MilesCranmer Jul 6, 2024
398b640
fix: equality check between dims
MilesCranmer Jul 6, 2024
571518c
fix: name of ValueInterfaceModule
MilesCranmer Jul 6, 2024
bebbe18
refactor: simplify `Max2Tensor` further
MilesCranmer Jul 6, 2024
e2e9b21
fix: `ValueInterface` checker of packing
MilesCranmer Jul 6, 2024
1b1a674
test: full `ValueInterface` for `Max2Tensor`
MilesCranmer Jul 6, 2024
5fa74fc
refactor: `Max2Tensor` to more general `DynamicTensor`
MilesCranmer Jul 6, 2024
455e80f
test: fix more generic `DynamicTensor`
MilesCranmer Jul 6, 2024
6e3db19
refactor: clean up use of `is_valid` throughout library
MilesCranmer Jul 6, 2024
1c878f2
refactor: propagate inbounds to unpacking and packing
MilesCranmer Jul 6, 2024
0c8aab9
refactor: clean up broadcasted operators
MilesCranmer Jul 7, 2024
5ce69c5
feat: warn for `BroadcastFunction`
MilesCranmer Jul 7, 2024
b162f75
feat: use safer `lock` syntax
MilesCranmer Jul 7, 2024
1b576fe
feat: only warn if defining helper functions
MilesCranmer Jul 7, 2024
7ff927b
fix: mark `@unstable`
MilesCranmer Jul 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using DispatchDoctor: @stable, @unstable

@stable default_mode = "disable" begin
include("Utils.jl")
include("TypeInterface.jl")
include("ExtensionInterface.jl")
include("OperatorEnum.jl")
include("Node.jl")
Expand Down
131 changes: 81 additions & 50 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@ import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: is_bad_array, fill_similar, counttuple, ResultOk
import ..NodeUtilsModule: is_constant
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
import ..TypeInterfaceModule: is_valid, is_valid_array

const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15

macro return_on_check(val, X)
:(
if !isfinite($(esc(val)))
if !is_valid($(esc(val)))
return $(ResultOk)(similar($(esc(X)), axes($(esc(X)), 2)), false)
end
)
end

macro return_on_nonfinite_array(array)
:(
if is_bad_array($(esc(array)))
if !is_valid_array($(esc(array)))
return $(ResultOk)($(esc(array)), false)
end
)
Expand Down Expand Up @@ -69,7 +70,7 @@ function eval_tree_array(
operators::OperatorEnum;
turbo::Union{Bool,Val}=Val(false),
bumper::Union{Bool,Val}=Val(false),
) where {T<:Number}
) where {T}
v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false))
v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false))
if v_turbo isa Val{true} || v_bumper isa Val{true}
Expand All @@ -79,20 +80,33 @@ function eval_tree_array(
_is_loopvectorization_loaded(0) ||
error("Please load the LoopVectorization.jl package to use this feature.")
end
if (v_turbo isa Val{true} || v_turbo isa Val{true}) && !(T <: Number)
error("Bumper feature only works with numbers")
end
if v_bumper isa Val{true}
return bumper_eval_tree_array(tree, cX, operators, v_turbo)
end

result = _eval_tree_array(tree, cX, operators, v_turbo)
return (result.x, result.ok && !is_bad_array(result.x))
return (result.x, result.ok && is_valid_array(result.x))
end

function eval_tree_array(
tree::AbstractExpressionNode{T},
cX::AbstractVector{T},
operators::OperatorEnum;
kws...
) where {T}
return eval_tree_array(tree, reshape(cX, (size(cX)[1], 1))::AbstractMatrix{T}, operators; kws...)
end

function eval_tree_array(
tree::AbstractExpressionNode{T1},
cX::AbstractMatrix{T2},
operators::OperatorEnum;
turbo::Union{Bool,Val}=Val(false),
bumper::Union{Bool,Val}=Val(false),
) where {T1<:Number,T2<:Number}
) where {T1,T2}
T = promote_type(T1, T2)
@warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)."
tree = convert(constructorof(typeof(tree)){T}, tree)
Expand All @@ -108,7 +122,7 @@ function _eval_tree_array(
cX::AbstractMatrix{T},
operators::OperatorEnum,
::Val{turbo},
)::ResultOk where {T<:Number,turbo}
)::ResultOk where {T,turbo}
# First, we see if there are only constants in the tree - meaning
# we can just return the constant result.
if tree.degree == 0
Expand All @@ -131,7 +145,7 @@ end

function deg2_eval(
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{false}
)::ResultOk where {T<:Number,F}
)::ResultOk where {T,F}
@inbounds @simd for j in eachindex(cumulator_l)
x = op(cumulator_l[j], cumulator_r[j])::T
cumulator_l[j] = x
Expand All @@ -141,7 +155,7 @@ end

function deg1_eval(
cumulator::AbstractVector{T}, op::F, ::Val{false}
)::ResultOk where {T<:Number,F}
)::ResultOk where {T,F}
@inbounds @simd for j in eachindex(cumulator)
x = op(cumulator[j])::T
cumulator[j] = x
Expand All @@ -151,7 +165,7 @@ end

function deg0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}
)::ResultOk where {T<:Number}
)::ResultOk where {T}
if tree.constant
return ResultOk(fill_similar(tree.val, cX, axes(cX, 2)), true)
else
Expand All @@ -165,7 +179,7 @@ end
op_idx::Integer,
operators::OperatorEnum,
::Val{turbo},
) where {T<:Number,turbo}
) where {T,turbo}
nbin = get_nbin(operators)
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
if long_compilation_time
Expand Down Expand Up @@ -219,7 +233,7 @@ end
op_idx::Integer,
operators::OperatorEnum,
::Val{turbo},
) where {T<:Number,turbo}
) where {T,turbo}
nuna = get_nuna(operators)
long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
if long_compilation_time
Expand Down Expand Up @@ -267,7 +281,7 @@ end
l_op_idx::Integer,
binops,
::Val{turbo},
) where {T<:Number,F,turbo}
) where {T,F,turbo}
nbin = counttuple(binops)
# (Note this is only called from dispatch_deg1_eval, which has already
# checked for long compilation times, so we don't need to check here)
Expand All @@ -288,7 +302,7 @@ end
l_op_idx::Integer,
unaops,
::Val{turbo},
)::ResultOk where {T<:Number,F,turbo}
)::ResultOk where {T,F,turbo}
nuna = counttuple(unaops)
quote
Base.Cartesian.@nif(
Expand All @@ -303,7 +317,7 @@ end

function deg1_l2_ll0_lr0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{false}
) where {T<:Number,F,F2}
) where {T,F,F2}
if tree.l.l.constant && tree.l.r.constant
val_ll = tree.l.l.val
val_lr = tree.l.r.val
Expand All @@ -321,7 +335,7 @@ function deg1_l2_ll0_lr0_eval(
cumulator = similar(cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(val_ll, cX[feature_lr, j])::T
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
cumulator[j] = x
end
return ResultOk(cumulator, true)
Expand All @@ -332,7 +346,7 @@ function deg1_l2_ll0_lr0_eval(
cumulator = similar(cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], val_lr)::T
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
cumulator[j] = x
end
return ResultOk(cumulator, true)
Expand All @@ -342,7 +356,7 @@ function deg1_l2_ll0_lr0_eval(
cumulator = similar(cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
cumulator[j] = x
end
return ResultOk(cumulator, true)
Expand All @@ -352,7 +366,7 @@ end
# op(op2(x)) for x variable or constant
function deg1_l1_ll0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{false}
) where {T<:Number,F,F2}
) where {T,F,F2}
if tree.l.l.constant
val_ll = tree.l.l.val
@return_on_check val_ll cX
Expand All @@ -366,7 +380,7 @@ function deg1_l1_ll0_eval(
cumulator = similar(cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j])::T
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
cumulator[j] = x
end
return ResultOk(cumulator, true)
Expand All @@ -376,7 +390,7 @@ end
# op(x, y) for x and y variable/constant
function deg2_l0_r0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{false}
) where {T<:Number,F}
) where {T,F}
if tree.l.constant && tree.r.constant
val_l = tree.l.val
@return_on_check val_l cX
Expand Down Expand Up @@ -424,7 +438,7 @@ function deg2_l0_eval(
cX::AbstractArray{T},
op::F,
::Val{false},
) where {T<:Number,F}
) where {T,F}
if tree.l.constant
val = tree.l.val
@return_on_check val cX
Expand All @@ -450,7 +464,7 @@ function deg2_r0_eval(
cX::AbstractArray{T},
op::F,
::Val{false},
) where {T<:Number,F}
) where {T,F}
if tree.r.constant
val = tree.r.val
@return_on_check val cX
Expand All @@ -470,15 +484,15 @@ function deg2_r0_eval(
end

"""
dispatch_constant_tree(tree::AbstractExpressionNode{T}, operators::OperatorEnum) where {T<:Number}
dispatch_constant_tree(tree::AbstractExpressionNode{T}, operators::OperatorEnum) where {T}

Evaluate a tree which is assumed to not contain any variable nodes. This
gives better performance, as we do not need to perform computation
over an entire array when the values are all the same.
"""
@generated function dispatch_constant_tree(
tree::AbstractExpressionNode{T}, operators::OperatorEnum
) where {T<:Number}
) where {T}
nuna = get_nuna(operators)
nbin = get_nbin(operators)
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
Expand Down Expand Up @@ -524,29 +538,29 @@ over an entire array when the values are all the same.
end
end

@inline function deg0_eval_constant(tree::AbstractExpressionNode{T}) where {T<:Number}
@inline function deg0_eval_constant(tree::AbstractExpressionNode{T}) where {T}
output = tree.val
return ResultOk([output], true)::ResultOk{Vector{T}}
end

function deg1_eval_constant(
tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum
) where {T<:Number,F}
) where {T,F}
result = dispatch_constant_tree(tree.l, operators)
!result.ok && return result
output = op(result.x[])::T
return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
return ResultOk([output], is_valid(output))::ResultOk{Vector{T}}
end

function deg2_eval_constant(
tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum
) where {T<:Number,F}
) where {T,F}
cumulator = dispatch_constant_tree(tree.l, operators)
!cumulator.ok && return cumulator
result_r = dispatch_constant_tree(tree.r, operators)
!result_r.ok && return result_r
output = op(cumulator.x[], result_r.x[])::T
return ResultOk([output], isfinite(output))::ResultOk{Vector{T}}
return ResultOk([output], is_valid(output))::ResultOk{Vector{T}}
end

"""
Expand Down Expand Up @@ -611,6 +625,8 @@ function deg2_diff_eval(
return ResultOk(out, all(isfinite, out))
end

get_lower_array_type(T,N) = N==1 ? T : AbstractArray{T,N-1}

"""
eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)

Expand Down Expand Up @@ -660,15 +676,17 @@ function eval(current_node)
that it was not defined for.
"""
@unstable function eval_tree_array(
tree::AbstractExpressionNode,
cX::AbstractArray,
tree::AbstractExpressionNode{T1},
cX::AbstractArray{T2, N},
operators::GenericOperatorEnum;
throw_errors::Bool=true,
)
!throw_errors && return _eval_tree_array_generic(tree, cX, operators, Val(false))
) #=::Tuple{get_lower_array_type(T1, N), Bool}=# where {T1,T2,N}
try
return _eval_tree_array_generic(tree, cX, operators, Val(true))
catch e
if !throw_errors
return nothing, false
end
tree_s = string_tree(tree, operators)
error_msg = "Failed to evaluate tree $(tree_s)."
if isa(e, MethodError)
Expand All @@ -686,49 +704,62 @@ end
tree::AbstractExpressionNode{T1},
cX::AbstractArray{T2,N},
operators::GenericOperatorEnum,
::Val{throw_errors},
) where {T1,T2,N,throw_errors}
::Val{throw_errors}
) #= :: Tuple{get_lower_array_type(T1, N), Bool} =# where {T1,T2,N,throw_errors}
if tree.degree == 0
if tree.constant
return (tree.val::T1), true
if N == 1
return (tree.val::T1), true
else
return fill(tree.val::T1, size(cX)[2:N]), true
end
else
if N == 1
return cX[tree.feature], true
return (cX[tree.feature]), true
else
return selectdim(cX, 1, tree.feature), true
end
end
elseif tree.degree == 1
return deg1_eval_generic(
tree, cX, operators.unaops[tree.op], operators, Val(throw_errors)
)
) #=::Tuple{get_lower_array_type(T1, N), Bool}=#
else
return deg2_eval_generic(
tree, cX, operators.binops[tree.op], operators, Val(throw_errors)
)
) #+::Tuple{get_lower_array_type(T1, N), Bool}=#
end
end

@unstable function deg1_eval_generic(
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
) where {F,throw_errors}
left, complete = eval_tree_array(tree.l, cX, operators)
tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N}, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
) #= :: Tuple{get_lower_array_type(T1, N), Bool} =# where {F,T1,T2,N,throw_errors}
left, complete = _eval_tree_array_generic(tree.l, cX, operators, Val(throw_errors))
!throw_errors && !complete && return nothing, false
!throw_errors && !hasmethod(op, Tuple{typeof(left)}) && return nothing, false
return op(left), true
!throw_errors && !hasmethod(op, N==1 ? Tuple{typeof(left)} : Tuple{eltype(left)}) && return nothing, false
if N == 1
return op(left), true
else
return op.(left), true
end

end

@unstable function deg2_eval_generic(
tree, cX, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
) where {F,throw_errors}
left, complete = eval_tree_array(tree.l, cX, operators)
tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N}, op::F, operators::GenericOperatorEnum, ::Val{throw_errors}
) #= :: Tuple{get_lower_array_type(T1, N), Bool} =# where {F,T1,T2,N,throw_errors}
left, complete = _eval_tree_array_generic(tree.l, cX, operators, Val(throw_errors))
!throw_errors && !complete && return nothing, false
right, complete = eval_tree_array(tree.r, cX, operators)
right, complete = _eval_tree_array_generic(tree.r, cX, operators, Val(throw_errors))
!throw_errors && !complete && return nothing, false
!throw_errors &&
!hasmethod(op, Tuple{typeof(left),typeof(right)}) &&
!hasmethod(op, N == 1 ? Tuple{typeof(left),typeof(right)} : Tuple{eltype(left),eltype(right)}) &&
return nothing, false
return op(left, right), true
if N == 1
return op(left, right), true
else
return op.(left, right), true
end
end

end
Loading