Skip to content

Commit a67873e

Browse files
authored
Merge branch 'master' into compathelper/new_version/2025-01-05-01-23-24-852-00256837307
2 parents 4e5ddba + eaef832 commit a67873e

File tree

7 files changed

+131
-18
lines changed

7 files changed

+131
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "1.9.2"
4+
version = "1.9.4"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/eval.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ tree([1 2 3; 4 5 6.], operators)
4747
```
4848

4949
This is possible because when you call `OperatorEnum`, it automatically re-defines
50-
`(::Node)(X)` to call the evaluation operation with the given `operators loaded.
50+
`(::Node)(X)` to call the evaluation operation with the given `operators` loaded.
5151
It also re-defines `print`, `show`, and the various operators, to work with the `Node` type.
5252

5353
!!! warning

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ import .NodeModule:
6969
get_scalar_constants,
7070
set_scalar_constants!
7171
@reexport import .StringsModule: string_tree, print_tree
72-
import .StringsModule: get_op_name
72+
import .StringsModule: get_op_name, get_pretty_op_name
7373
@reexport import .OperatorEnumModule: AbstractOperatorEnum
7474
@reexport import .OperatorEnumConstructionModule:
7575
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!

src/ExpressionAlgebra.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ end
154154
for op in (
155155
:*, :/, :+, :-, :^, :÷, :mod, :log,
156156
:atan, :atand, :copysign, :flipsign,
157-
:&, :|, :, ://, :\,
157+
:&, :|, :, ://, :\, :rem,
158+
:(>), :(<), :(>=), :(<=), :max, :min,
158159
)
159160
@eval @declare_expression_operator Base.$(op) 2
160161
end

src/Strings.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,24 @@ using ..UtilsModule: deprecate_varmap
44
using ..OperatorEnumModule: AbstractOperatorEnum
55
using ..NodeModule: AbstractExpressionNode, tree_mapreduce
66

7-
function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
8-
if deg == 1
9-
return vcat(collect("unary_operator["), collect(string(idx)), [']'])
10-
else
11-
return vcat(collect("binary_operator["), collect(string(idx)), [']'])
12-
end
7+
function dispatch_op_name(
8+
::Val{deg}, ::Nothing, idx, pretty::Bool
9+
)::Vector{Char} where {deg}
10+
return vcat(
11+
collect(deg == 1 ? "unary_operator[" : "binary_operator["),
12+
collect(string(idx)),
13+
[']'],
14+
)
1315
end
14-
function dispatch_op_name(::Val{deg}, operators::AbstractOperatorEnum, idx) where {deg}
15-
if deg == 1
16-
return collect(get_op_name(operators.unaops[idx])::String)
16+
function dispatch_op_name(
17+
::Val{deg}, operators::AbstractOperatorEnum, idx, pretty::Bool
18+
) where {deg}
19+
op = if deg == 1
20+
operators.unaops[idx]
1721
else
18-
return collect(get_op_name(operators.binops[idx])::String)
22+
operators.binops[idx]
1923
end
24+
return collect((pretty ? get_pretty_op_name(op) : get_op_name(op))::String)
2025
end
2126

2227
const OP_NAME_CACHE = (; x=Dict{UInt64,String}(), lock=Threads.SpinLock())
@@ -47,6 +52,9 @@ function get_op_name(op::F) where {F}
4752
unlock(OP_NAME_CACHE.lock)
4853
end
4954
end
55+
function get_pretty_op_name(op::F) where {F}
56+
return get_op_name(op)
57+
end
5058

5159
@inline function strip_brackets(s::Vector{Char})::Vector{Char}
5260
if first(s) == '(' && last(s) == ')'
@@ -82,7 +90,7 @@ end
8290

8391
# Vector of chars is faster than strings, so we use that.
8492
function combine_op_with_inputs(op, l, r)::Vector{Char}
85-
if first(op) in ('+', '-', '*', '/', '^', '.')
93+
if first(op) in ('+', '-', '*', '/', '^', '.', '>', '<', '=') || op == "!="
8694
# "(l op r)"
8795
out = ['(']
8896
append!(out, l)
@@ -145,8 +153,9 @@ function string_tree(
145153
raw::Union{Bool,Nothing}=nothing,
146154
varMap=nothing,
147155
)::String where {T,F1<:Function,F2<:Function}
148-
!isnothing(raw) &&
156+
if !isnothing(raw)
149157
Base.depwarn("`raw` is deprecated; use `pretty` instead", :string_tree)
158+
end
150159
pretty = @something(pretty, _not(raw), false)
151160
variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
152161
raw_output = tree_mapreduce(
@@ -162,9 +171,9 @@ function string_tree(
162171
end,
163172
let operators = operators
164173
(branch,) -> if branch.degree == 1
165-
dispatch_op_name(Val(1), operators, branch.op)::Vector{Char}
174+
dispatch_op_name(Val(1), operators, branch.op, pretty)::Vector{Char}
166175
else
167-
dispatch_op_name(Val(2), operators, branch.op)::Vector{Char}
176+
dispatch_op_name(Val(2), operators, branch.op, pretty)::Vector{Char}
168177
end
169178
end,
170179
combine_op_with_inputs,

test/test_expressions.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,48 @@ end
454454
@test get_variable_names(new_ex2, nothing) == ["x1"]
455455
@test get_operators(new_ex2, nothing) == new_operators
456456
end
457+
458+
@testitem "New binary operators" begin
459+
using DynamicExpressions
460+
461+
operators = OperatorEnum(;
462+
binary_operators=[+, -, *, /, >, <, >=, <=, max, min, rem],
463+
unary_operators=[sin, cos],
464+
)
465+
x1, x2 = [Node(Float64; feature=i) for i in 1:2]
466+
467+
# Test comparison operators string representation
468+
tree = x1 > x2
469+
@test string(tree) == "x1 > x2"
470+
471+
tree = x1 < x2
472+
@test string(tree) == "x1 < x2"
473+
474+
tree = x1 >= x2
475+
@test string(tree) == "x1 >= x2"
476+
477+
tree = x1 <= x2
478+
@test string(tree) == "x1 <= x2"
479+
480+
# Test max/min operators
481+
tree = max(x1, x2)
482+
X = [1.0 2.0; 3.0 1.0]' # Two points: (1,3) and (2,1)
483+
@test tree(X, operators) [2.0, 3.0]
484+
485+
tree = min(x1, x2)
486+
@test tree(X, operators) [1.0, 1.0]
487+
488+
# Test remainder operator
489+
tree = rem(x1, x2)
490+
X = [5.0 7.0; 3.0 2.0]' # Two points: (5,7) and (3,2)
491+
@test tree(X, operators) [5.0, 1.0]
492+
493+
# Test combinations string representation
494+
tree = max(x1, 2.0) > min(x2, 3.0)
495+
@test string(tree) == "max(x1, 2.0) > min(x2, 3.0)"
496+
497+
# Test with constants
498+
tree = rem(x1, 2.0)
499+
X = [5.0 7.0] # Two points: 5 and 7
500+
@test tree(X, operators) [1.0, 1.0]
501+
end

test/test_print.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,61 @@ end
117117
@test string(tree) == "(k1 * x2) + x3"
118118
empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x)
119119
end
120+
121+
@testset "Test pretty format for operators" begin
122+
# Define a custom operator with different pretty representation
123+
@eval begin
124+
my_pretty_op(x, y) = x + y
125+
DE.get_op_name(::typeof(my_pretty_op)) = "my_pretty_op"
126+
DE.get_pretty_op_name(::typeof(my_pretty_op)) = "pretty_op_two"
127+
end
128+
129+
operators = OperatorEnum(;
130+
default_params...,
131+
binary_operators=(+, *, /, -, my_pretty_op),
132+
unary_operators=(cos, sin),
133+
)
134+
@extend_operators operators
135+
136+
x1, x2 = [Node(; feature=i) for i in 1:2]
137+
138+
# Test default format (not pretty)
139+
tree = my_pretty_op(x1, x2)
140+
@test string_tree(tree, operators) == "my_pretty_op(x1, x2)"
141+
142+
# Test pretty format
143+
@test string_tree(tree, operators; pretty=true) == "pretty_op_two(x1, x2)"
144+
145+
# Test with nested expressions
146+
tree = sin(my_pretty_op(x1, x2))
147+
@test string_tree(tree, operators) == "sin(my_pretty_op(x1, x2))"
148+
@test string_tree(tree, operators; pretty=true) == "sin(pretty_op_two(x1, x2))"
149+
150+
# Test with constants
151+
tree = my_pretty_op(x1, Node(; val=3.14))
152+
@test string_tree(tree, operators) == "my_pretty_op(x1, 3.14)"
153+
@test string_tree(tree, operators; pretty=true) == "pretty_op_two(x1, 3.14)"
154+
155+
# Test that the default implementation of get_pretty_op_name falls back to get_op_name
156+
tree = sin(x1)
157+
@test string_tree(tree, operators) == "sin(x1)"
158+
@test string_tree(tree, operators; pretty=true) == "sin(x1)"
159+
160+
# Test with a unary operator that has a different pretty name
161+
@eval begin
162+
my_unary_op(x) = sin(x)
163+
DE.get_op_name(::typeof(my_unary_op)) = "my_unary_op"
164+
DE.get_pretty_op_name(::typeof(my_unary_op)) = "sine"
165+
end
166+
167+
operators_with_unary = OperatorEnum(;
168+
default_params...,
169+
binary_operators=(+, *, /, -),
170+
unary_operators=(cos, sin, my_unary_op),
171+
)
172+
@extend_operators operators_with_unary
173+
174+
tree = my_unary_op(x1)
175+
@test string_tree(tree, operators_with_unary) == "my_unary_op(x1)"
176+
@test string_tree(tree, operators_with_unary; pretty=true) == "sine(x1)"
177+
end

0 commit comments

Comments
 (0)