Skip to content

opticstyles: to be or not to be #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ jobs:
fail-fast: false
matrix:
version:
- 1.6 # LTS
- 1
- 1.6
- 1.7
- 1.8
- 1.9
- '1.10'
- '1.11-nightly'
- 'nightly'
os:
- ubuntu-latest
# - macOS-latest
- windows-latest
# - windows-latest
arch:
- x64
steps:
Expand Down
1 change: 0 additions & 1 deletion examples/custom_optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ end
#
using Accessors
struct Keys end
Accessors.OpticStyle(::Type{Keys}) = ModifyBased()
Accessors.modify(f, obj, ::Keys) = mapkeys(f, obj)
# It can be used as follows:
obj = Dict("A" =>1, "B" => 2, "C" => 3)
Expand Down
5 changes: 2 additions & 3 deletions examples/specter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
using Test
using Accessors

import Accessors: modify, OpticStyle
using Accessors: ModifyBased, SetBased, setindex
import Accessors: modify
using Accessors: setindex

# ### Increment all even numbers
# We have the following data and the goal is to increment all nested even numbers.
Expand All @@ -22,7 +22,6 @@ end
mapvals(f, nt::NamedTuple) = map(f, nt)

struct Vals end
OpticStyle(::Type{Vals}) = ModifyBased()
modify(f, obj, ::Vals) = mapvals(f, obj)

# Now we can increment as follows:
Expand Down
4 changes: 2 additions & 2 deletions src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ getall(obj::AbstractString, ::Elements) = collect(obj)
getall(obj, ::Elements) = error("Elements() not supported for $(typeof(obj))")
getall(obj, ::Properties) = getproperties(obj) |> values
getall(obj, o::If) = o.modify_condition(obj) ? (obj,) : ()
getall(obj, o) = OpticStyle(o) == SetBased() ? (o(obj),) : error("`getall` not supported for $o")
getall(obj, o) = (o(obj),)

function setall(obj, ::Properties, vs)
names = propertynames(obj)
Expand All @@ -84,7 +84,7 @@ function setall(obj, o::If, vs)
obj
end
end
setall(obj, o, vs) = OpticStyle(o) == SetBased() ? set(obj, o, only(vs)) : error("`setall` not supported for $o")
setall(obj, o, vs) = set(obj, o, only(vs))


# implementations for composite optics
Expand Down
83 changes: 14 additions & 69 deletions src/optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,49 +130,12 @@ const ComposedOptic{Outer,Inner} = ComposedFunction{Outer,Inner}
outertype(::Type{ComposedOptic{Outer,Inner}}) where {Outer,Inner} = Outer
innertype(::Type{ComposedOptic{Outer,Inner}}) where {Outer,Inner} = Inner

# TODO better name
# also better way to organize traits will
# probably only emerge over time
#
# TODO
# There is an inference regression as of Julia v1.7.0
# if recursion is combined with trait based dispatch
# https://github.com/JuliaLang/julia/issues/43296

abstract type OpticStyle end
struct ModifyBased <: OpticStyle end
struct SetBased <: OpticStyle end
# Base.@pure OpticStyle(obj) = OpticStyle(typeof(obj))
function OpticStyle(optic::T) where {T}
OpticStyle(T)
end
# defining lenses should be very lightweight
# e.g. only a single `set` implementation
# so we choose this as the default trait
OpticStyle(::Type{T}) where {T} = SetBased()

function OpticStyle(::Type{ComposedOptic{O,I}}) where {O,I}
composed_optic_style(OpticStyle(O), OpticStyle(I))
end
composed_optic_style(::SetBased, ::SetBased) = SetBased()
composed_optic_style(::ModifyBased, ::SetBased) = ModifyBased()
composed_optic_style(::SetBased, ::ModifyBased) = ModifyBased()
composed_optic_style(::ModifyBased, ::ModifyBased) = ModifyBased()

@inline function set(obj, optic::O, val) where {O}
_set(obj, optic, val, OpticStyle(O))
end

function _set(obj, optic, val, ::SetBased)
inv_func = inverse(optic)
if !(inv_func isa NoInverse)
return inv_func(val)
end
Optic = typeof(optic)
error("""
This should be unreachable. You probably need to overload
`Accessors.set(obj, ::$Optic, val)
""")
modify(Returns(val), obj, optic)
end

if VERSION < v"1.7"
Expand All @@ -184,38 +147,24 @@ else
using Base: Returns
end

@inline function _set(obj, optic, val, ::ModifyBased)
modify(Returns(val), obj, optic)
end

@inline function _set(obj, optic::ComposedOptic, val, ::SetBased)
inner_obj = optic.inner(obj)
inner_val = set(inner_obj, optic.outer, val)
set(obj, optic.inner, inner_val)
end

@inline function modify(f, obj, optic::O) where {O}
_modify(f, obj, optic, OpticStyle(O))
end

function _modify(f, obj, optic, ::ModifyBased)
Optic = typeof(optic)
error("""
This should be unreachable. You probably need to overload:
`Accessors.modify(f, obj, ::$Optic)`
""")
set(obj, optic, f(optic(obj)))
end

function _modify(f, obj, optic::ComposedOptic, ::ModifyBased)
otr = optic.outer
inr = optic.inner
modify(obj, inr) do o1
modify(f, o1, otr)
@inline function set(obj, optic::ComposedOptic, val)
optics = decompose(optic)
_modifyc(obj, Base.tail(optics)) do o1
set(o1, first(optics), val)
end
end
@inline modify(f, obj, optic::ComposedOptic) = _modifyc(f, obj, decompose(optic))

@inline function _modify(f, obj, optic, ::SetBased)
set(obj, optic, f(optic(obj)))
@inline _modifyc(f, obj, os::Tuple{}) = f(obj)
for N in [1:10; :(<: Any)]
@eval @inline _modifyc(f, obj, os::NTuple{$N,Any}) =
modify(obj, last(os)) do o1
_modifyc(f, o1, Base.front(os))
end
end

function delete(obj, optic::ComposedOptic)
Expand Down Expand Up @@ -249,7 +198,6 @@ julia> modify(x -> 2x, obj, Elements())
```
"""
struct Elements end
OpticStyle(::Type{<:Elements}) = ModifyBased()

modify(f, obj, ::Elements) = map(f, obj)
# sets and dicts don't support map(), but still have the concept of elements:
Expand All @@ -275,7 +223,6 @@ $EXPERIMENTAL
struct If{C}
modify_condition::C
end
OpticStyle(::Type{<:If}) = ModifyBased()

function modify(f, obj, w::If)
if w.modify_condition(obj)
Expand Down Expand Up @@ -338,7 +285,6 @@ julia> modify(x -> 2x, obj, Properties())
Based on [`mapproperties`](@ref).
"""
struct Properties end
OpticStyle(::Type{<:Properties}) = ModifyBased()
modify(f, o, ::Properties) = mapproperties(f, o)

"""
Expand All @@ -365,9 +311,8 @@ struct Recursive{Descent,Optic}
descent_condition::Descent
optic::Optic
end
OpticStyle(::Type{Recursive{D,O}}) where {D,O} = ModifyBased() # Is this a good idea?

function _modify(f, obj, r::Recursive, ::ModifyBased)
function modify(f, obj, r::Recursive)
modify(obj, r.optic) do o
if r.descent_condition(o)
modify(f, o, r)
Expand Down
6 changes: 2 additions & 4 deletions test/perf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,14 @@ end
println("Right associative composition: $b_right")

@test b_default.allocs == 0
if VERSION >= v"1.10-"
@test_broken b_right.allocs == 0
elseif VERSION >= v"1.7"
if VERSION >= v"1.7"
@test b_right.allocs == 0
else
@test_broken right.allocs == 0
@test b_right.time > 2b_default.time
end
@test b_left.allocs == 0
@test b_left.time ≈ b_default.time rtol=0.8
@test b_right.time ≈ b_default.time rtol=0.8
end

end
22 changes: 6 additions & 16 deletions test/test_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ end
s = @set t.b.b.a.a = 5
@test t === T(1, T(2, T(T(4,4),3)))
@test s === T(1, T(2, T(T(5, 4), 3)))
@test_throws ArgumentError @set t.b.b.a.a.a = 3
@test_throws Exception @set t.b.b.a.a.a = 3

t = T(1,2)
@test T(1, T(1,2)) === @set t.b = T(1,2)
@test_throws ArgumentError @set t.c = 3
@test_throws Exception @set t.c = 3

t = T(T(2,2), 1)
s = @set t.a.a = 3
Expand Down Expand Up @@ -220,19 +220,9 @@ end
((@optic _.b.a.b[end]), 4.0),
((@optic _.b.a.b[end÷2+1]), 4.0),
]
if VERSION < v"1.7" || VERSION >= v"1.10-"
@inferred lens(obj)
@inferred set(obj, lens, val)
@inferred modify(identity, obj, lens)
else
@inferred lens(obj)
@inferred set(obj, lens, val)
@test_broken begin
# https://github.com/JuliaLang/julia/issues/43296
@inferred modify(identity, obj, lens)
true
end
end
@inferred lens(obj)
@inferred set(obj, lens, val)
@inferred modify(identity, obj, lens)
end
end

Expand Down Expand Up @@ -601,7 +591,7 @@ end

@testset "@accessor" begin
s = MyStruct((a=123,))
@test strip(string(@doc(my_x))) == "Documentation for my_x"
# @test strip(string(@doc(my_x))) == "Documentation for my_x"
@test (@set my_x(s) = 456) === MyStruct(456)
@test (@set +s = 456) === MyStruct((a=5-456,))
test_getset_laws(my_x, s, 456, "1")
Expand Down
4 changes: 2 additions & 2 deletions test/test_extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ VERSION >= v"1.9-" && @testset "StructArrays" begin
sa = @set sb.a = 1:3
@test sa.a === 1:3
@test sa.b === sb.b
@test_throws ArgumentError @set sb.c = 1:3
@test_throws Exception @set sb.c = 1:3
sd = @delete sb.a
@test sd::StructArray == StructArray(b=10:12)
@test_throws "only eltypes with fields" @delete s.a
Expand All @@ -172,7 +172,7 @@ VERSION >= v"1.9-" && @testset "StructArrays" begin
@test @inferred(set(s, PropertyLens(:a), 10:11))::StructArray == StructArray([S(10, 2), S(11, 4)])
@test @inferred(set(s, PropertyLens(:a), [:a, :b]))::StructArray == StructArray([S(:a, 2), S(:b, 4)])

@test_throws "need to overload" set(s, propertynames, (:x, :y))
@test_throws Exception set(s, propertynames, (:x, :y))
s = StructArray(x=[1, 2], y=[:a, :b])
test_getset_laws(propertynames, s, (:u, :v), (1, 2))
test_getset_laws(propertynames, s, (1, 2), (:u, :v))
Expand Down
10 changes: 9 additions & 1 deletion test/test_functionlenses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ end
# @optic(parse(Int, _)) isa Base.Fix1{typeof(parse), Type{T}} where {T}
# doesn't hold
@test @inferred(modify(x -> -2x, "3", @optic parse(Int, _))) == "-6"
@test_throws ErrorException modify(log10, "100", @optic parse(Int, _))
@test_throws Exception modify(log10, "100", @optic parse(Int, _))
@test modify(log10, "100", @optic parse(Float64, _)) == "2.0"
test_getset_laws(@optic(parse(Int, _)), "3", -10, 123)
test_getset_laws(@optic(parse(Float64, _)), "3.0", -10., 123.)
Expand Down Expand Up @@ -391,4 +391,12 @@ end
test_getset_laws(o2, x, 2, -3)
end

@testset "non-callable" begin
struct MyF end
Accessors.set(x, ::MyF, y) = y + 1
Accessors.modify(f, x, ::MyF) = f(x)
@test set(1, (@o _ + 2 |> MyF()), 3) == 2
@test modify(x -> 10x, 1, (@o _ + 2 |> MyF())) == 28
end

end # module