Skip to content

Commit 01f01ac

Browse files
committed
feat: allow partial updates to with_metadata
1 parent dde9291 commit 01f01ac

File tree

4 files changed

+43
-13
lines changed

4 files changed

+43
-13
lines changed

src/Expression.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,25 @@ import ..SimplifyModule: combine_operators, simplify_tree!
3131
struct Metadata{NT<:NamedTuple}
3232
_data::NT
3333
end
34-
_data(x::Metadata) = getfield(x, :_data)
34+
unpack_metadata(x) = x # Fallback for when the user doesn't use the Metadata type
35+
unpack_metadata(x::Metadata) = getfield(x, :_data)
3536

36-
Base.propertynames(x::Metadata) = propertynames(_data(x))
37-
@unstable @inline Base.getproperty(x::Metadata, f::Symbol) = getproperty(_data(x), f)
38-
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", _data(x), ")")
37+
Base.propertynames(x::Metadata) = propertynames(unpack_metadata(x))
38+
@unstable @inline function Base.getproperty(x::Metadata, f::Symbol)
39+
return getproperty(unpack_metadata(x), f)
40+
end
41+
Base.show(io::IO, x::Metadata) = print(io, "Metadata(", unpack_metadata(x), ")")
3942
@inline _copy(x) = copy(x)
4043
@inline _copy(x::NamedTuple) = copy_named_tuple(x)
4144
@inline _copy(x::Nothing) = nothing
4245
@inline function copy_named_tuple(nt::NamedTuple)
4346
return NamedTuple{keys(nt)}(map(_copy, values(nt)))
4447
end
4548
@inline function Base.copy(metadata::Metadata)
46-
return Metadata(_copy(_data(metadata)))
49+
return Metadata(_copy(unpack_metadata(metadata)))
4750
end
48-
@inline Base.:(==)(x::Metadata, y::Metadata) = _data(x) == _data(y)
49-
@inline Base.hash(x::Metadata, h::UInt) = hash(_data(x), h)
51+
@inline Base.:(==)(x::Metadata, y::Metadata) = unpack_metadata(x) == unpack_metadata(y)
52+
@inline Base.hash(x::Metadata, h::UInt) = hash(unpack_metadata(x), h)
5053

5154
"""
5255
AbstractExpression{T,N}
@@ -216,7 +219,9 @@ end
216219
Create a new expression based on `ex` but with a different `metadata`.
217220
"""
218221
function with_metadata(ex::AbstractExpression; metadata...)
219-
return with_metadata(ex, Metadata((; metadata...)))
222+
return with_metadata(
223+
ex, Metadata((; unpack_metadata(get_metadata(ex))..., metadata...))
224+
)
220225
end
221226
function with_metadata(ex::AbstractExpression, metadata::Metadata)
222227
return constructorof(typeof(ex))(get_contents(ex), metadata)

src/ParametricExpression.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
55

66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77
using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
8-
using ..ExpressionModule: AbstractExpression, Metadata, with_contents, with_metadata
8+
using ..ExpressionModule:
9+
AbstractExpression, Metadata, with_contents, with_metadata, unpack_metadata
910
using ..ChainRulesModule: NodeTangent
1011

1112
import ..NodeModule:
@@ -63,7 +64,6 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T}
6364
return n
6465
end
6566
end
66-
@inline _data(x::Metadata) = getfield(x, :_data)
6767

6868
"""
6969
ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
@@ -79,7 +79,9 @@ struct ParametricExpression{
7979
metadata::Metadata{D}
8080

8181
function ParametricExpression(tree::ParametricNode, metadata::Metadata)
82-
return new{eltype(tree),typeof(tree),typeof(_data(metadata))}(tree, metadata)
82+
return new{eltype(tree),typeof(tree),typeof(unpack_metadata(metadata))}(
83+
tree, metadata
84+
)
8385
end
8486
end
8587
function ParametricExpression(

src/StructuredExpression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import ..ExpressionModule:
1616
with_contents,
1717
Metadata,
1818
_copy,
19-
_data,
19+
unpack_metadata,
2020
default_node_type,
2121
node_type,
2222
get_scalar_constants,
@@ -114,7 +114,7 @@ constructorof(::Type{<:StructuredExpression}) = StructuredExpression
114114
function Base.copy(e::AbstractStructuredExpression)
115115
ts = get_contents(e)
116116
meta = get_metadata(e)
117-
meta_inner = _data(meta)
117+
meta_inner = unpack_metadata(meta)
118118
copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts)))
119119
keys_except_structure = filter(!=(:structure), keys(meta_inner))
120120
copy_metadata = (;

test/test_expressions.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,26 @@ end
413413

414414
#literate_end
415415
end
416+
417+
@testitem "Expression with_metadata partial updates" begin
418+
using DynamicExpressions
419+
using DynamicExpressions: get_operators, get_metadata, with_metadata, get_variable_names
420+
421+
# Create an expression with initial metadata
422+
ex = @parse_expression(
423+
x1 + 1.5,
424+
operators = OperatorEnum(; binary_operators=[+, *]),
425+
variable_names = ["x1"]
426+
)
427+
428+
# Update only the variable_names, keeping the original operators
429+
new_ex = with_metadata(ex; variable_names=["y1"])
430+
@test get_variable_names(new_ex, nothing) == ["y1"]
431+
@test get_operators(new_ex, nothing) == get_operators(ex, nothing)
432+
433+
# Update only the operators, keeping the original variable_names
434+
new_operators = OperatorEnum(; binary_operators=[+])
435+
new_ex2 = with_metadata(ex; operators=new_operators)
436+
@test get_variable_names(new_ex2, nothing) == ["x1"]
437+
@test get_operators(new_ex2, nothing) == new_operators
438+
end

0 commit comments

Comments
 (0)