Skip to content

Commit

Permalink
rewriting just needs tests
Browse files Browse the repository at this point in the history
  • Loading branch information
quffaro committed Sep 6, 2024
1 parent 7f8597a commit 154e51f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
34 changes: 17 additions & 17 deletions src/symbolictheoryutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,38 @@ macro operator(head, body)
end
(f, types, Theory) = ph(head)

# Passing types to functions requires that we type the signature with ::Type{T}.
# This means that the user would have to write
# my_op(::Type{T1}, ::Type{T2}, ...)
# As a convenience to the user, we allow them to specify the signature using just the types themselves.
# my_op(T1, T2, ...)
sort_types = [:(::Type{$S}) for S in types]
sort_constraints = [:($S<:$Theory) for S in types]
arity = length(sort_types)

# Parse the body for @match or @rule calls. The @match statement parsing is unsophisticated; multiple
# @match statements will be added, and there is currently no validation.
match_calls = []; rule_calls = [];
pb = begin
Expr(:block, args...) => pb.(args)
PatMatch(e) => push!(match_calls, e)
PatRule(e) => push!(rule_calls, e)
s => nothing
end
pb(body);
end; pb(body);

# initialize the result
result = quote end

# DEFINE TYPE INFERENCE IN THE ThDEC SYSTEM
push!(result.args, quote
function $f end; export $f
# construct the function on basic symbolics
push!(result.args, quote
@nospecialize
function $f(args...)
s = promote_symtype($f, args...)
SymbolicUtils.Term{s}($f, [args...])
end
export $f
end)


# we want to feed symtype the generics
push!(result.args, quote
function SymbolicUtils.promote_symtype(::typeof($f), $(sort_types...)) where {$(sort_constraints...)}
Expand All @@ -115,17 +125,7 @@ macro operator(head, body)
end
end)

# CONSTRUCT THE FUNCTION ON BASIC SYMBOLICS
push!(result.args, quote
@nospecialize
function $f(args...)
s = promote_symtype($f, args...)
SymbolicUtils.Term{s}($f, [args...])
end
export $f
end)

push!(result.args, quote $rule_calls end)
push!(result.args, Expr(:tuple, rule_calls...))

return esc(result)
end
Expand Down
9 changes: 7 additions & 2 deletions test/decasymbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2}
@test Term(a * b) == Mult(Term[Var(:a), Var(:b)])
@test Term du) == App2(:₁₁, Var(), Var(:du))

# test promoting types
@test promote_symtype(d, u) == PrimalForm{1, :X, 2}
@test promote_symtype(+, a, b) == Scalar
@test promote_symtype(, u, u) == PrimalForm{0, :X, 2}
@test promote_symtype(, u, ω) == PrimalForm{1, :X, 2}

# test composition
@test promote_symtype(d d, u) == PrimalForm{2, :X, 2}

end

@testset "Operator definition" begin
Expand All @@ -58,8 +61,8 @@ end
PatScalar(_) => error("Argument of type $S is invalid")
PatForm(_) => promote_symtype(★ d d, S)
end
@rule ~~x::isForm0 => (d((d(x))))
@rule ~~x::isForm1 => (d((d(x)))) + d((d((x))))
@rule ~x::isForm0 => (d((d(~x))))
@rule ~x::isForm1 => (d((d(~x)))) + d((d((~x))))
end;
# TODO rewriting not working atm
# del_expand = Chain(del_expand0, del_expand1)
Expand All @@ -68,6 +71,8 @@ end
@test symtype((u)) == PrimalForm{0, :X ,2}

@test_broken promote_symtype(Δ, [u,v])

@test del_expand_0(u) == (d((d(u))))
end

@testset "Conversion" begin
Expand Down

0 comments on commit 154e51f

Please sign in to comment.