Skip to content

Commit

Permalink
Create Base.Fix as general Fix1/Fix2 for partially-applied func…
Browse files Browse the repository at this point in the history
…tions (JuliaLang#54653)

This PR generalises `Base.Fix1` and `Base.Fix2` to `Base.Fix{N}`, to
allow fixing a single positional argument of a function.

With this change, the implementation of these is simply

```julia
const Fix1{F,T} = Fix{1,F,T}
const Fix2{F,T} = Fix{2,F,T}
```

Along with the PR I also add a larger suite of unittests for all three
of these functions to complement the existing tests for `Fix1`/`Fix2`.

### Context

There are multiple motivations for this generalization.
**By creating a more general `Fix{N}` type, there is no preferential
treatment of certain types of functions:**

- (i) No limitation that you can only fix positions 1-2. You can now fix
any position `n`.
- (ii) No asymmetry between 2-argument and n-argument functions. You can
now fix an argument for functions with any number of arguments.

Think of this like if `Base` only had `Vector{T}` and `Matrix{T}`, and
you wished to generalise it to `Array{T,N}`.
It is an analogous situation here: `Fix1` and `Fix2` are now *aliases*
of `Fix{N}`.

- **Convenience**:
- `Base.Fix1` and `Base.Fix2` are useful shorthands for creating simple
anonymous functions without compiling new functions.
- They are common throughout the Julia ecosystem as a shorthand for
filling arguments:
- `Fix1`
https://github.com/search?q=Base.Fix1+language%3Ajulia&type=code
- `Fix2`
https://github.com/search?q=Base.Fix2+language%3Ajulia&type=code
- **Less Compilation**:
- Using `Fix*` reduces the need for compilation of repeatedly-used
anonymous functions (which can often trigger compilation of new
functions).
- **Type Stability**:
- `Fix`, like `Fix1` and `Fix2`, captures variables in a struct,
encouraging users to use a functional paradigm for closures, preventing
any potential type instabilities from boxed variables within an
anonymous function.
- **Easier Functional Programming**:
- Allows for a stronger functional programming paradigm by supporting
partial functions with _any number of arguments_.

Note that this refactors `Fix1` and `Fix2` to be equal to `Fix{1}` and
`Fix{2}` respectively, rather than separate structs. This is backwards
compatible.

Also note that this does not constrain future generalisations of
`Fix{n}` for multiple arguments. `Fix{1,F,T}` is the clear
generalisation of `Fix1{F,T}`, so this isn't major new syntax choices.
But in a future PR you could have, e.g., `Fix{(n1,n2)}` for multiple
arguments, and it would still be backwards-compatible with this.

---------

Co-authored-by: Dilum Aluthge <dilum@aluthge.com>
Co-authored-by: Lilith Orion Hafner <lilithhafner@gmail.com>
Co-authored-by: Alexander Plavin <alexander@plav.in>
Co-authored-by: Neven Sajko <s@purelymail.com>
  • Loading branch information
5 people authored and lazarusA committed Aug 17, 2024
1 parent 86d13ab commit 79e992e
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 23 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ New library functions
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).

New library features
--------------------
Expand Down
57 changes: 36 additions & 21 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1154,40 +1154,55 @@ julia> filter(!isletter, str)
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f

"""
Fix1(f, x)
Fix{N}(f, x)
A type representing a partially-applied version of the two-argument function
`f`, with the first argument fixed to the value "x". In other words,
`Fix1(f, x)` behaves similarly to `y->f(x, y)`.
A type representing a partially-applied version of a function `f`, with the argument
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
See also [`Fix2`](@ref Base.Fix2).
!!! compat "Julia 1.12"
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
are available earlier.
!!! note
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
available arguments, rather than an absolute ordering on the target function. For example,
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
fixes the first and third arg.
"""
struct Fix1{F,T} <: Function
struct Fix{N,F,T} <: Function
f::F
x::T

Fix1(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
Fix1(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
function Fix{N}(f::F, x) where {N,F}
if !(N isa Int)
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
elseif N < 1
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
end
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
end
end

(f::Fix1)(y) = f.f(f.x, y)
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
end

"""
Fix2(f, x)
# Special cases for improved constant propagation
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)

A type representing a partially-applied version of the two-argument function
`f`, with the second argument fixed to the value "x". In other words,
`Fix2(f, x)` behaves similarly to `y->f(y, x)`.
"""
struct Fix2{F,T} <: Function
f::F
x::T
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
"""
const Fix1{F,T} = Fix{1,F,T}

Fix2(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
Fix2(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
end
"""
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
"""
const Fix2{F,T} = Fix{2,F,T}

(f::Fix2)(y) = f.f(y, f.x)

"""
isequal(x)
Expand Down
1 change: 1 addition & 0 deletions base/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public
AsyncCondition,
CodeUnits,
Event,
Fix,
Fix1,
Fix2,
Generator,
Expand Down
1 change: 1 addition & 0 deletions doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ Base.:(|>)
Base.:(∘)
Base.ComposedFunction
Base.splat
Base.Fix
Base.Fix1
Base.Fix2
```
Expand Down
4 changes: 2 additions & 2 deletions stdlib/REPL/test/repl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1216,9 +1216,9 @@ global some_undef_global
@test occursin("does not exist", sprint(show, help_result("..")))
# test that helpmode is sensitive to contextual module
@test occursin("No documentation found", sprint(show, help_result("Fix2", Main)))
@test occursin("A type representing a partially-applied version", # exact string may change
@test occursin("Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).", # exact string may change
sprint(show, help_result("Base.Fix2", Main)))
@test occursin("A type representing a partially-applied version", # exact string may change
@test occursin("Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).", # exact string may change
sprint(show, help_result("Fix2", Base)))


Expand Down
126 changes: 126 additions & 0 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,129 @@ end
let (:)(a,b) = (i for i in Base.:(:)(1,10) if i%2==0)
@test Int8[ i for i = 1:2 ] == [2,4,6,8,10]
end

@testset "Basic tests of Fix1, Fix2, and Fix" begin
function test_fix1(Fix1=Base.Fix1)
increment = Fix1(+, 1)
@test increment(5) == 6
@test increment(-1) == 0
@test increment(0) == 1
@test map(increment, [1, 2, 3]) == [2, 3, 4]

concat_with_hello = Fix1(*, "Hello ")
@test concat_with_hello("World!") == "Hello World!"
# Make sure inference is good:
@inferred concat_with_hello("World!")

one_divided_by = Fix1(/, 1)
@test one_divided_by(10) == 1/10.0
@test one_divided_by(-5) == 1/-5.0

return nothing
end

function test_fix2(Fix2=Base.Fix2)
return_second = Fix2((x, y) -> y, 999)
@test return_second(10) == 999
@inferred return_second(10)
@test return_second(-5) == 999

divide_by_two = Fix2(/, 2)
@test map(divide_by_two, (2, 4, 6)) == (1.0, 2.0, 3.0)
@inferred map(divide_by_two, (2, 4, 6))

concat_with_world = Fix2(*, " World!")
@test concat_with_world("Hello") == "Hello World!"
@inferred concat_with_world("Hello World!")

return nothing
end

# Test with normal Base.Fix1 and Base.Fix2
test_fix1()
test_fix2()

# Now, repeat the Fix1 and Fix2 tests, but
# with a Fix lambda function used in their place
test_fix1((op, arg) -> Base.Fix{1}(op, arg))
test_fix2((op, arg) -> Base.Fix{2}(op, arg))

# Now, we do more complex tests of Fix:
let Fix=Base.Fix
@testset "Argument Fixation" begin
let f = (x, y, z) -> x + y * z
fixed_f1 = Fix{1}(f, 10)
@test fixed_f1(2, 3) == 10 + 2 * 3

fixed_f2 = Fix{2}(f, 5)
@test fixed_f2(1, 4) == 1 + 5 * 4

fixed_f3 = Fix{3}(f, 3)
@test fixed_f3(1, 2) == 1 + 2 * 3
end
end
@testset "Helpful errors" begin
let g = (x, y) -> x - y
# Test minimum N
fixed_g1 = Fix{1}(g, 100)
@test fixed_g1(40) == 100 - 40

# Test maximum N
fixed_g2 = Fix{2}(g, 100)
@test fixed_g2(150) == 150 - 100

# One over
fixed_g3 = Fix{3}(g, 100)
@test_throws ArgumentError("expected at least 2 arguments to `Fix{3}`, but got 1") fixed_g3(1)
end
end
@testset "Type Stability and Inference" begin
let h = (x, y) -> x / y
fixed_h = Fix{2}(h, 2.0)
@test @inferred(fixed_h(4.0)) == 2.0
end
end
@testset "Interaction with varargs" begin
vararg_f = (x, y, z...) -> x + 10 * y + sum(z; init=zero(x))
fixed_vararg_f = Fix{2}(vararg_f, 6)

# Can call with variable number of arguments:
@test fixed_vararg_f(1, 2, 3, 4) == 1 + 10 * 6 + sum((2, 3, 4))
@inferred fixed_vararg_f(1, 2, 3, 4)
@test fixed_vararg_f(5) == 5 + 10 * 6
@inferred fixed_vararg_f(5)
end
@testset "Errors should propagate normally" begin
error_f = (x, y) -> sin(x * y)
fixed_error_f = Fix{2}(error_f, Inf)
@test_throws DomainError fixed_error_f(10)
end
@testset "Chaining Fix together" begin
f1 = Fix{1}(*, "1")
f2 = Fix{1}(f1, "2")
f3 = Fix{1}(f2, "3")
@test f3() == "123"

g1 = Fix{2}(*, "1")
g2 = Fix{2}(g1, "2")
g3 = Fix{2}(g2, "3")
@test g3("") == "123"
end
@testset "Zero arguments" begin
f = Fix{1}(x -> x, 'a')
@test f() == 'a'
end
@testset "Dummy-proofing" begin
@test_throws ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got 0") Fix{0}(>, 1)
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `0.5::Float64`") Fix{0.5}(>, 1)
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `1::UInt64`") Fix{UInt64(1)}(>, 1)
end
@testset "Specialize to structs not in `Base`" begin
struct MyStruct
x::Int
end
f = Fix{1}(MyStruct, 1)
@test f isa Fix{1,Type{MyStruct},Int}
end
end
end

0 comments on commit 79e992e

Please sign in to comment.