From 543a877515441516dfbde2db5f946232aebe8fed Mon Sep 17 00:00:00 2001 From: Martin Holters Date: Wed, 13 Jul 2022 03:32:07 +0200 Subject: [PATCH] Improve evaluation of nested `ComposedFunction`s (#45925) * ~~Add (the equivalent of) `@assume_effects :terminates_globally` to `unwrap_composed`. Although it could be inferred as `Const`, without the annotation, it was not elided for too complex inputs, resulting in unnecessary runtime overhead.~~ EDIT: now #45993 is merged and this part isn't included. * Reverse recursion order in `call_composed`. This prevents potentially changing argument types being piped through the recursion, making inference bail out. With the reversed order, only the tuple of remaining functions is changing during recursion and is becoming strictly simpler, letting inference succeed. Co-authored-by: Shuhei Kadowaki --- base/operators.jl | 11 ++++------- test/operators.jl | 7 +++++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/base/operators.jl b/base/operators.jl index 20e65707ad59dc..8f11e3b5747067 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -1017,14 +1017,11 @@ struct ComposedFunction{O,I} <: Function ComposedFunction(outer, inner) = new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner) end -function (c::ComposedFunction)(x...; kw...) - fs = unwrap_composed(c) - call_composed(fs[1](x...; kw...), tail(fs)...) -end -unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...) +(c::ComposedFunction)(x...; kw...) = call_composed(unwrap_composed(c), x, kw) +unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.outer)..., unwrap_composed(c.inner)...) unwrap_composed(c) = (maybeconstructor(c),) -call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...)) -call_composed(x, f) = f(x) +call_composed(fs, x, kw) = (@inline; fs[1](call_composed(tail(fs), x, kw))) +call_composed(fs::Tuple{Any}, x, kw) = fs[1](x...; kw...) struct Constructor{F} <: Function end (::Constructor{F})(args...; kw...) where {F} = (@inline; F(args...; kw...)) diff --git a/test/operators.jl b/test/operators.jl index 46833e1280eea2..65d5e5726b3120 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -2,6 +2,8 @@ using Random: randstring +include("compiler/irutils.jl") + @testset "ifelse" begin @test ifelse(true, 1, 2) == 1 @test ifelse(false, 1, 2) == 2 @@ -182,6 +184,11 @@ end @test (@inferred g(1)) == ntuple(Returns(1), 13) h = (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ sum @test (@inferred h((1, 2, 3); init = 0.0)) == 6.0 + issue_45877 = reduce(∘, fill(sin,500)) + @test Core.Compiler.is_foldable(Base.infer_effects(Base.unwrap_composed, (typeof(issue_45877),))) + @test fully_eliminated() do + issue_45877(1.0) + end end @testset "function negation" begin