-
-
Notifications
You must be signed in to change notification settings - Fork 15
/
functor.jl
115 lines (88 loc) · 2.65 KB
/
functor.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# `functor(x)` returns a pair `(children, re)`: the traversable children of `x`
# and a function `re` that rebuilds a value of the same kind from (possibly
# transformed) children. Dispatch happens on the *type* of `x`, so a type opts
# in by adding a `functor(::Type{<:MyType}, x)` method.

# Fallback: anything without a more specific method is a leaf. It has no
# children, and its reconstructor ignores its argument and returns `x` itself.
functor(T, x) = ((), _ -> x)
functor(x) = functor(typeof(x), x)

# Tuples, named tuples and arrays are their own children; the reconstructor
# simply hands back whatever collection it receives.
functor(::Type{<:Tuple}, x) = (x, ys -> ys)
functor(::Type{<:NamedTuple}, x) = (x, ys -> ys)
functor(::Type{<:AbstractArray}, x) = (x, ys -> ys)

# ...except arrays whose element type is a subtype of `Number`: those are
# treated as leaves rather than as containers to recurse into.
functor(::Type{<:AbstractArray{<:Number}}, x) = ((), _ -> x)
"""
    makefunctor(m::Module, T, fs = fieldnames(T))

Define `Functors.functor(::Type{<:T}, x)` by evaluating it inside module `m`.

The fields of `T` named in `fs` become the children: a `NamedTuple` with one
entry per name, in the order given by `fs`. The generated reconstructor
`y -> T(...)` calls `T`'s positional constructor with the fields in
`fieldnames(T)` order. A field listed in `fs` is taken from `y` at that field's
position within `fs`. Every other field is copied from the original `x`.
"""
function makefunctor(m::Module, T, fs = fieldnames(T))
    # Index each child by its position in `fs`, not by a running counter over
    # `fieldnames(T)`. The children NamedTuple is built in `fs` order, so a
    # counter would pair fields with the wrong values whenever `fs` lists them
    # in a different order than the struct declares them.
    escargs = map(fieldnames(T)) do f
        i = findfirst(==(f), fs)
        i === nothing ? :(x.$f) : :(y[$i])
    end
    escfs = [:($f = x.$f) for f in fs]
    @eval m begin
        $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
    end
end
"""
    functorm(T, fs = nothing)

Build the expression that `@functor` expands to: a call to
`makefunctor(@__MODULE__, T, fields...)`, with `T` escaped so that it resolves
in the caller's scope.

`fs` is either `nothing` or a literal tuple expression of bare field names such
as `:((a, b))`.

- With `nothing`, no field list is passed on, so `makefunctor` falls back to
  `fieldnames(T)`.
- With a tuple, the names are passed on as a tuple of quoted `Symbol`s.
- Anything else raises an error showing the expected usage.
"""
function functorm(T, fs = nothing)
    if fs === nothing
        fieldargs = ()
    else
        # Only a literal tuple of plain identifiers is meaningful as a field
        # list; reject anything else up front instead of generating bad code.
        (Meta.isexpr(fs, :tuple) && all(a -> a isa Symbol, fs.args)) ||
            error("@functor T (a, b)")
        fieldargs = (Expr(:tuple, map(QuoteNode, fs.args)...),)
    end
    return :(makefunctor(@__MODULE__, $(esc(T)), $(fieldargs...)))
end
"""
    @functor T
    @functor T (a, b, ...)

Define [`functor`](@ref) for the struct `T` so that it can be traversed.

- With no field list, every field of `T` is a child.
- With a tuple of field names, only those fields are children. When the object
  is rebuilt with `T`'s positional constructor, the remaining fields are copied
  over unchanged from the original object.
"""
macro functor(args...)
    expansion = functorm(args...)
    return expansion
end
"""
isleaf(x)
Return true if `x` has no [`children`](@ref) according to [`functor`](@ref).
"""
isleaf(x) = children(x) === ()
"""
children(x)
Return the children of `x` as defined by [`functor`](@ref).
Equivalent to `functor(x)[1]`.
"""
children(x) = functor(x)[1]
"""
    fmap1(f, x)

Apply `f` to each immediate child of `x` and rebuild `x` from the results.

Children and the reconstructor come from [`functor`](@ref). This function
itself goes only one level deep; any recursion must come from `f`.
"""
function fmap1(f, x)
    kids, rebuild = functor(x)
    mapped = map(f, kids)
    return rebuild(mapped)
end
"""
    fmap(f, x; exclude = isleaf, cache = IdDict())

Recursively rebuild `x`, applying `f` to every node for which `exclude(node)`
is `true` (by default the leaves, see [`isleaf`](@ref)).

Every other node is taken apart with [`functor`](@ref), its children are mapped
recursively, and it is reassembled from the results. `f` is not applied to such
nodes themselves.

Each node's result is stored in `cache` and reused when the same node is met
again. With the default `IdDict`, lookup is by identity (`===`). A value shared
between several parents is therefore mapped once, and all of its occurrences in
the output refer to the same result. This also applies to equal immutable
values, such as repeated occurrences of the number `1`.

See https://github.com/FluxML/Functors.jl/issues/2 for a discussion regarding
the need for the cache.
"""
function fmap(f, x; exclude = isleaf, cache = IdDict())
    if haskey(cache, x)
        return cache[x]
    end
    walk = child -> fmap(f, child; exclude = exclude, cache = cache)
    result = exclude(x) ? f(x) : fmap1(walk, x)
    cache[x] = result
    return result
end
"""
fcollect(x; exclude = v -> false)
Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
and collecting the results into a flat array.
Doesn't recurse inside branches rooted at nodes `v`
for which `exclude(v) == true`.
In such cases, the root `v` is also excluded from the result.
By default, `exclude` always yields `false`.
See also [`children`](@ref).
# Examples
```jldoctest
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> struct NoChildren; x; y; end
julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b))
julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
NoChildren(:a, :b)
julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
NoChildren(:a, :b)
julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
```
"""
function fcollect(x; cache = [], exclude = v -> false)
x in cache && return cache
if !exclude(x)
push!(cache, x)
foreach(y -> fcollect(y; cache = cache, exclude = exclude), children(x))
end
return cache
end