Skip to content

Commit

Permalink
added tests for decasymbolic, but needs wrinkles ironed out. musical …
Browse files Browse the repository at this point in the history
…isos also given placeholder nameof methods
  • Loading branch information
quffaro committed Aug 18, 2024
1 parent c2d64b5 commit 5928cc7
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 64 deletions.
37 changes: 10 additions & 27 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,23 @@
module DiagrammaticEquations

export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
# Deca
op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D,
recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
## collages
Collage, collate,
## composition
oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram,
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
resolve_overloads!, replace_names!,
apply_inference_rule_op1!, apply_inference_rule_op2!,
transfer_parents!, transfer_children!,
unique_lits!,
## language
@decapode, Term, parse_decapode, term, Eq, DecaExpr,
# ~~~~~
Plus, AppCirc1, Var, Tan, App1, App2,
## visualization
to_graphviz_property_graph, typename, draw_composition,
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!
DerivOp, append_dot, normalize_unicode,

## intertypes
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, DecaExpr, Plus, AppCirc1, Var, Tan, App1, App2, Eq

using Catlab
using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, , , , associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
using Catlab.CategoricalAlgebra
import Catlab.CategoricalAlgebra:
using Catlab.WiringDiagrams
using Catlab.WiringDiagrams.DirectedWiringDiagrams
using Catlab.ACSetInterface
using MLStyle
import Unicode
using Reexport

## TODO:
## generate schema from a _theory_
Expand All @@ -64,6 +44,9 @@ include("learn/Learn.jl")
include("ThDEC.jl")
include("decasymbolic.jl")

using .Deca
@reexport using .ThDEC
@reexport using .SymbolicUtilsInterop
@reexport using .Deca


end
13 changes: 13 additions & 0 deletions src/ThDEC.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module ThDEC

using MLStyle

import Base: +, -, *
Expand Down Expand Up @@ -142,6 +143,7 @@ function ★(s::Sort)
@match s begin
Scalar() => throw(SortError("Cannot take Hodge star of a scalar"))
Form(i, isdual) => Form(2 - i, !isdual)
VF(isdual) => throw(SortError("Cannot take the Hodge star of a vector field"))
end
end

Expand Down Expand Up @@ -172,13 +174,24 @@ function ♯(s::Sort)
end
# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf.

# TODO
function Base.nameof(::typeof(♯), s)
Symbol("♯s")
end


function (s::Sort)
@match s begin
VF(true) => PrimalForm(1)
_ => throw(SortError("Can only apply ♭ to dual vector fields"))
end
end

# TODO
function Base.nameof(::typeof(♭), s)
Symbol("♭s")
end

# OTHER

function ♭♯(s::Sort)
Expand Down
99 changes: 62 additions & 37 deletions src/decasymbolic.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
module SymbolicUtilInterop
module SymbolicUtilsInterop

using ..ThDEC
using MLStyle
import ..ThDEC: Sort, dim, isdual
using ..decapodes
using SymbolicUtils

using SymbolicUtils: Symbolic, BasicSymbolic

# ##########################
# DECType
#
# Type necessary for symbolic utils
# ##########################

# define DECType as a Number. Necessary for SymbolicUtils
abstract type DECType <: Number end

"""
i: dimension: 0,1,2, etc.
d: duality: true = dual, false = primal
"""
struct FormT{i,d} <: DECType
end
struct FormT{i,d} <: DECType end

struct VFieldT{d} <: DECType
end
struct VFieldT{d} <: DECType end

dim(::Type{<:FormT{d}}) where {d} = d
isdual(::Type{FormT{i,d}}) where {i,d} = d
Expand All @@ -35,46 +41,41 @@ export PrimalVFT
const DualVFT = VFieldT{true}
export DualVFT

function Sort(::Type{FormT{i,d}}) where {i,d}
Form(i, d)
end
# convert Real to DecType
Sort(::Type{<:Real}) = Scalar()

function Number(f::Form)
FormT{dim(f),isdual(f)}
end
# convert Real to ThDEC
Sort(::Real) = Scalar()

function Sort(::Type{VFieldT{d}}) where {d}
VField(d)
end
# convert DECType to ThDEC
Sort(::Type{FormT{i,d}}) where {i,d} = Form(i, d)

function Number(v::VField)
VFieldT{isdual(v)}
end
# convert DECType to ThDEC
Sort(::Type{VFieldT{d}}) where {d} = VField(d)

function Sort(::Type{<:Real})
Scalar()
end
Sort(::BasicSymbolic{T}) where {T} = Sort(T)

function Number(s::Scalar)
Real
end
# convert Form to DECType
Number(f::Form) = FormT{dim(f), isdual(f)}

function Sort(::BasicSymbolic{T}) where {T}
Sort(T)
end
# convert VField to DECType
Number(v::VField) = VFieldT{isdual(v)}

function Sort(::Real)
Scalar()
end
# convert number to real
Number(s::Scalar) = Real

# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term
unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-]
for unop in unop_dec
@eval begin
@nospecialize
function ThDEC.$unop(
v::BasicSymbolic{T}
) where {T<:DECType}
# convert the DECType to ThDEC to type check
s = ThDEC.$unop(Sort(T))
# the resulting type is converted back to DECType
# the resulting term has the operation has its head and `v` as its args.
SymbolicUtils.Term{Number(s)}(ThDEC.$unop, [v])
end
end
Expand All @@ -91,6 +92,7 @@ for binop in binop_dec
s = ThDEC.$binop(Sort(T1), Sort(T2))
SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w])
end
export $binop

@nospecialize
function ThDEC.$binop(
Expand All @@ -100,6 +102,7 @@ for binop in binop_dec
s = ThDEC.$binop(Sort(T1), Sort(T2))
SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w])
end
export $binop

@nospecialize
function ThDEC.$binop(
Expand All @@ -109,19 +112,25 @@ for binop in binop_dec
s = ThDEC.$binop(Sort(T1), Sort(T2))
SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w])
end
export $binop
end
end

struct Equation{E}
# name collision with decapodes.Equation
struct DecaEquation{E}
lhs::E
rhs::E
end
export DecaEquation

# a struct carry the symbolic variables and their equations
struct DecaSymbolic
vars::Vector{Symbolic}
equations::Vector{Equation{Symbolic}}
equations::Vector{DecaEquation{Symbolic}}
end
export DecaSymbolic

# BasicSymbolic -> DecaExpr
function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
if SymbolicUtils.issym(t)
decapodes.Var(nameof(t))
Expand All @@ -146,25 +155,39 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
end
end

function decapodes.Term(x::Real)
decapodes.Lit(Symbol(x))
end
decapodes.Term(x::Real) = decapodes.Lit(Symbol(x))

function decapodes.DecaExpr(d::DecaSymbolic)
context = map(d.vars) do var
decapodes.Judgement(nameof(var), nameof(Sort(var)), :I)
# TODO changed :I to :X to make tests pass, but discussion
# needed on handling spaces
decapodes.Judgement(nameof(var), nameof(Sort(var)), :X)
end
equations = map(d.equations) do eq
decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs))
end
decapodes.DecaExpr(context, equations)
end

"""
Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC
Example:
```
a = @syms a::Real
context = Dict(:a => Scalar(), :u => PrimalForm(0))
SymbolicUtils.BasicSymbolic(context, Term(a))
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term)
@match t begin
Var(name) => SymbolicUtils.Sym{Number(context[name])}(name)
Lit(v) => Meta.parse(string(v)) # YOLO
Lit(v) => Meta.parse(string(v)) # TODO no YOLO
# see heat_eq test: eqs had AppCirc1, but this returns
# App1(f, App1(...)
AppCirc1(fs, arg) => foldr(
# panics with constants like :k
# see test/language.jl
(f, x) -> ThDEC.OPERATOR_LOOKUP[f](x),
fs;
init=BasicSymbolic(context, arg)
Expand All @@ -178,15 +201,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Te
end

function DecaSymbolic(d::decapodes.DecaExpr)
# associates each var to its sort...
context = map(d.context) do j
j.var => ThDEC.SORT_LOOKUP[j.dim]
end
# ... which we then produce a vector of symbolic vars
vars = map(context) do (v, s)
SymbolicUtils.Sym{Number(s)}(v)
end
context = Dict{Symbol,Sort}(context)
eqs = map(d.equations) do eq
Equation{Symbolic}(BasicSymbolic(context, eq.lhs), BasicSymbolic(context, eq.rhs))
DecaEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...)
end
DecaSymbolic(vars, eqs)
end
Expand Down
85 changes: 85 additions & 0 deletions test/decasymbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using Test
#
using DiagrammaticEquations
using DiagrammaticEquations.ThDEC
using DiagrammaticEquations.decapodes
#
using SymbolicUtils

@testset "ThDEC Signature checking" begin
@test Scalar() + Scalar() == Scalar()
end

# load up some variable variables and expressions
a, b = @syms a::Real b::Real
u, v = @syms u::PrimalFormT{0} du::PrimalFormT{1}
ω, η = @syms ω::PrimalFormT{1} η::DualFormT{2}
ϕ, ψ = @syms ϕ::PrimalVFT ψ::DualVFT

expr_scalar_addition = a + b
expr_primal_wedge = ThDEC.:(ω, du)

@testset "Term Construction" begin

# test conversion to underlying type
@test Sort(a) == Scalar()
@test Sort(u) == PrimalForm(0)
@test Sort(ω) == PrimalForm(1)
@test Sort(η) == DualForm(2)
@test Sort(ϕ) == PrimalVF()
@test Sort(ψ) == DualVF()

@test_throws ThDEC.SortError ThDEC.(u)

# test unary operator conversion to decaexpr
@test Term(1) == DiagrammaticEquations.decapodes.Lit(Symbol("1"))
@test Term(a) == Var(:a)
@test Term(ThDEC.∂ₜ(u)) == Tan(Var(:u))
@test Term(ThDEC.(ω)) == App1(:★₁, Var())
@test Term(ThDEC.(ψ)) == App1(:♭s, Var())
# @test Term(DiagrammaticEquations.ThDEC.♯(du))

@test_throws ThDEC.SortError ThDEC.(ϕ)

# test binary operator conversion to decaexpr
@test Term(a + b) == Plus(Term[Var(:a), Var(:b)])
@test Term(a * b) == DiagrammaticEquations.decapodes.Mult(Term[Var(:a), Var(:b)])
@test Term(ThDEC.:(ω, du)) == App2(:₁₁, Var(), Var(:du))

end

@testset "Moving between DecaExpr and DecaSymbolic" begin end

context = Dict(:a => Scalar(), :b => Scalar()
,:u => PrimalForm(0), :du => PrimalForm(1))

js = [Judgement(:u, :Form0, :X)
,Judgement(:∂ₜu, :Form0, :X)
,Judgement(:Δu, :Form0, :X)]
eqs = [Eq(Var(:∂ₜu), AppCirc1([:₂⁻¹, :d₁, :₁, :d₀], Var(:u)))
,Eq(Tan(Var(:u)), Var(:∂ₜu))]
heat_eq = DecaExpr(js, eqs)


symb_heat_eq = DecaSymbolic(heat_eq)
deca_expr = DecaExpr(symb_heat_eq)

@test js == deca_expr.context

# eqs in the left has AppCirc1[vector, term]
# deca_expr.equations on the right has nested App1
# expected behavior is that nested AppCirc1 is preserved
@test_broken eqs == deca_expr.equations

# copied from test/language
js = [Judgement(:C, :Form0, :X),
Judgement(:Ċ₁, :Form0, :X),
Judgement(:Ċ₂, :Form0, :X)
]
# TODO: Do we need to handle the fact that all the functions are parameterized by a space?
eqs = [Eq(Var(:Ċ₁), AppCirc1([:₀⁻¹, :dual_d₁, :₁, :k, :d₀], Var(:C))),
Eq(Var(:Ċ₂), AppCirc1([:₀⁻¹, :dual_d₁, :₁, :d₀], Var(:C))),
Eq(Tan(Var(:C)), Plus([Var(:Ċ₁), Var(:Ċ₂)]))
]
diffusion_d = DecaExpr(js, eqs)

0 comments on commit 5928cc7

Please sign in to comment.