Skip to content

[WIP] first prototype of default value rules implemented #749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,48 @@
# 2. Dictionary
# 3. Callback: takes arguments Dictionary × Number of elements matched
#

function matcher(val::Any)
iscall(val) && return term_matcher(val)
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
if iscall(val)
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val)
# else return a normal term matcher
else
return term_matcher_constructor(val)
end
end

function literal_matcher(next, data, bindings)
# car data is the first element of data
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
end
end

function matcher(slot::Slot)
function slot_matcher(next, data, bindings)
!islist(data) && return
!islist(data) && return nothing
val = get(bindings, slot.name, nothing)
# if slot name already is in bindings, check if it matches
if val !== nothing
if isequal(val, car(data))
return next(bindings, 1)
end
else
if slot.predicate(car(data))
next(assoc(bindings, slot.name, car(data)), 1)
end
# elseif the first element of data matches the slot predicate, add it to bindings and call next
elseif slot.predicate(car(data))
next(assoc(bindings, slot.name, car(data)), 1)
end
end
end

# this is called only when defslot_term_matcher finds the operation and tries
# to match it, so no default value used. So the same function as slot_matcher
# can be used
function matcher(defslot::DefSlot)
matcher(Slot(defslot.name, defslot.predicate))
end

# returns n == offset, 0 if failed
function trymatchexpr(data, value, n)
if !islist(value)
Expand Down Expand Up @@ -84,13 +103,73 @@ function matcher(segment::Segment)
end
end

function term_matcher(term)
function term_matcher_constructor(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)

function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
# explenation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
# the length of the list, is considered empty
end

loop(car(data), bindings, matchers) # Try to eat exactly one term
end
end

# creates a matcher for a term containing a defslot, such as:
# (~x + ...complicated pattern...) * ~!y
# normal part (can bee a tree) operation defslot part

# defslot_term_matcher works like this:
# checks wether data starts with the default operation.
# if yes (1): continues like term_matcher
# if no checks wether data matches the normal part
# if no returns nothing, rule is not applied
# if yes (2): adds the pair (default value name, default value) to the found bindings and
# calls the success function like term_matcher would do

function defslot_term_matcher_constructor(term)
a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term

defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
defslot = a[defslot_index]

function defslot_term_matcher(success, data, bindings)
# if data is not a list, return nothing
!islist(data) && return nothing
!iscall(car(data)) && return nothing
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
if !iscall(car(data)) || (istree(car(data)) && nameof(operation(car(data))) != defslot.operation)
other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part

# checks wether it matches the normal part
# <-----------------(2)------------------------------->
bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)

if bindings === nothing
return nothing
end
return success(bindings, 1)
end

# (1)
function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
Expand Down
99 changes: 82 additions & 17 deletions src/rule.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

@inline alwaystrue(x) = true

# Matcher patterns with Slot and Segment
# Matcher patterns with Slot, DefSlot and Segment

# matches one term
# syntax: ~x
Expand All @@ -16,6 +16,79 @@ Base.isequal(s1::Slot, s2::Slot) = s1.name == s2.name

Base.show(io::IO, s::Slot) = (print(io, "~"); print(io, s.name))

# for when the slot is a symbol, like `~x`
makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s))

# for when the slot is an expression, like `~x::predicate`
function makeslot(s::Expr, keys)
if !(s.head == :(::))
error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function")
end

name = s.args[1]

push!(keys, name)
:(Slot($(QuoteNode(name)), $(esc(s.args[2]))))
end






# matches one term with built in default value.
# syntax: ~!x
# Example usage:
# (~!x + ~y) can match (a + b) but also just "a" and x takes default value of zero.
# (~!x)*(~y) can match a*b but also just "a", and x takes default value of one.
# (~x + ~y)^(~!z) can match (a + b)^c but also just "a + b", and z takes default value of one.
# only these three operations are supported for default values.

struct DefSlot{P, O}
name::Symbol
predicate::P
operation::O
defaultValue::Real
end

# operation | default
# + | 0
# * | 1
# ^ | 1
function defaultValOfCall(call)
if call == :+
return 0
elseif call == :*
return 1
elseif call == :^
return 1
end
# else no default value for this call
error("You can use default slots only with +, * and ^, but you tried with: $call")
end

DefSlot(s) = DefSlot(s, alwaystrue, nothing, 0)
Base.isequal(s1::DefSlot, s2::DefSlot) = s1.name == s2.name
Base.show(io::IO, s::DefSlot) = (print(io, "~!"); print(io, s.name))

makeDefSlot(s::Symbol, keys, op) = (push!(keys, s); DefSlot(s, alwaystrue, op, defaultValOfCall(op)))

function makeDefSlot(s::Expr, keys, op)
if !(s.head == :(::))
error("Syntax for specifying a default slot is ~!x::\$predicate, where predicate is a boolean function")
end

name = s.args[1]

push!(keys, name)
tmp = defaultValOfCall(op)
:(DefSlot($(QuoteNode(name)), $(esc(s.args[2])), $(esc(op))), $(esc(tmp)))
end





# matches zero or more terms
# syntax: ~~x
struct Segment{F}
Expand All @@ -37,37 +110,29 @@ function makesegment(s::Expr, keys)
end

name = s.args[1]

push!(keys, name)
:(Segment($(QuoteNode(name)), $(esc(s.args[2]))))
end

makeslot(s::Symbol, keys) = (push!(keys, s); Slot(s))

function makeslot(s::Expr, keys)
if !(s.head == :(::))
error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function")
end

name = s.args[1]

push!(keys, name)
:(Slot($(QuoteNode(name)), $(esc(s.args[2]))))
end

function makepattern(expr, keys)
# parent call is needed to know which default value to give if any default slots are present
function makepattern(expr, keys, parentCall=nothing)
if expr isa Expr
if expr.head === :call
if expr.args[1] === :(~)
if expr.args[2] isa Expr && expr.args[2].args[1] == :(~)
# matches ~~x::predicate
makesegment(expr.args[2].args[2], keys)
elseif expr.args[2] isa Expr && expr.args[2].args[1] == :(!)
# matches ~!x::predicate
makeDefSlot(expr.args[2].args[2], keys, parentCall)
else
# matches ~x::predicate
makeslot(expr.args[2], keys)
end
else
:(term($(map(x->makepattern(x, keys), expr.args)...); type=Any))
# make a pattern for every argument of the expr.
:(term($(map(x->makepattern(x, keys, operation(expr)), expr.args)...); type=Any))
end
elseif expr.head === :ref
:(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any))
Expand Down
35 changes: 35 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,41 @@ end
@eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, [])
end

@testset "Slot matcher with default value" begin
r_sum = @rule (~x + ~!y)^2 => ~y
@test r_sum((a + b)^2) === b
@test r_sum(b^2) === 0

r_mult = @rule ~x * ~!y => ~y
@test r_mult(a * b) === b
@test r_mult(a) === 1

r_mult2 = @rule (~x * ~!y + ~z) => ~y
@test r_mult2(c + a*b) === b
@test r_mult2(c + b) === 1

# here the "normal part" in the defslot_term_matcher is not a symbol but a tree
r_mult3 = @rule (~!x)*(~y + ~z) => ~x
@test r_mult3(a*(c+2)) === a
@test r_mult3(2*(c+2)) === 2
@test r_mult3(c+2) === 1

r_pow = @rule (~x)^(~!m) => ~m
@test r_pow(a^(b+1)) === b+1
@test r_pow(a) === 1
@test r_pow(a+1) === 1

# here the "normal part" in the defslot_term_matcher is not a symbol but a tree
r_pow2 = @rule (~x + ~y)^(~!m) => ~m
@test r_pow2((a+b)^c) === c
@test r_pow2(a+b) === 1

r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c
@test r_mix((a + b*c)^2) === 2 + c
@test r_mix((a + b*c)) === 1 + c
@test r_mix((a + b)) === 2 #1+1
end

using SymbolicUtils: @capture

@testset "Capture form" begin
Expand Down
Loading