Skip to content

Commit 8d2797e

Browse files
gca30MilesCranmer
andauthored
Apply suggestions from code review
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent da1886f commit 8d2797e

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

src/Evaluate.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function eval_tree_array(
8181
error("Please load the LoopVectorization.jl package to use this feature.")
8282
end
8383
if (v_turbo isa Val{true} || v_bumper isa Val{true}) && !(T <: Number)
84-
error("Bumper feature only works with numbers")
84+
error("Bumper and LoopVectorization features are only compatible with numeric element types")
8585
end
8686
if v_bumper isa Val{true}
8787
return bumper_eval_tree_array(tree, cX, operators, v_turbo)
@@ -97,7 +97,7 @@ function eval_tree_array(
9797
operators::OperatorEnum;
9898
kws...
9999
) where {T}
100-
return eval_tree_array(tree, reshape(cX, (size(cX)[1], 1))::AbstractMatrix{T}, operators; kws...)
100+
return eval_tree_array(tree, reshape(cX, (size(cX, 1), 1)), operators; kws...)
101101
end
102102

103103
function eval_tree_array(

test/test_non_number_eval_tree_array.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct SVM{T}
88
scalar :: T
99
vector :: Vector{T}
1010
matrix :: Matrix{T}
11-
SVM{T}() where {T} = new(Int8(0), zero(T), T[], T[;;])
11+
SVM{T}() where {T} = new(Int8(0), zero(T), T[], Array{T}(undef, 0, 0))
1212
SVM{T}(scalar :: W) where {T, W <: Number} = new(Int8(0), Base.convert(T, scalar), T[], T[;;])
1313
SVM{T}(vector :: Vector{W}) where {T, W <: Number} = new(Int8(1), zero(T), Vector{T}(vector) , T[;;])
1414
SVM{T}(matrix :: Matrix{W}) where {T, W <: Number} = new(Int8(2), zero(T), T[], Matrix{T}(matrix))
@@ -29,8 +29,7 @@ end
2929
function Base.:(==)(x::SVM{T}, y::SVM{T}) where T
3030
if x.dims !== y.dims
3131
return false
32-
end
33-
if x.dims == 0
32+
elseif x.dims == 0
3433
return x.scalar == y.scalar
3534
elseif val.dims == 1
3635
return x.vector == y.vector
@@ -61,7 +60,7 @@ Base.invokelatest(() -> begin
6160
@test !hasmethod(a, Tuple{Node{SVM{Float32}}, Node{SVM{Float32}}})
6261

6362
tree = a(Node{SVM{Float64}}(; feature=1), SVM{Float64}(3.0))
64-
results = tree([SVM{Float64}(1.0);; SVM{Float64}(2.0);; SVM{Float64}(3.0)])
63+
results = tree([SVM{Float64}(1.0) SVM{Float64}(2.0) SVM{Float64}(3.0)])
6564
@test results == [SVM{Float64}(4), SVM{Float64}(5), SVM{Float64}(6)]
6665

6766

0 commit comments

Comments
 (0)