Skip to content

Commit

Permalink
introduce @noinfer macro to tell the compiler to avoid excess infer…
Browse files Browse the repository at this point in the history
…ence

This commit introduces new compiler annotation named `@noinfer`, which
requests the compiler to avoid excess inference.

In order to discuss `@noinfer`, it would help a lot to understand the
behavior of `@nospecialize`.

Its docstring says simply:
> This is only a hint for the compiler to avoid excess code generation.

More specifically, it works by _suppressing dispatches_ with complex
runtime types of the annotated arguments. This could be understood with
the example below:
```julia
julia> function invokef(f, itr)
           local r = 0
           r += f(itr[1])
           r += f(itr[2])
           r += f(itr[3])
           r
       end;

julia> _isa = isa; # just for the sake of explanation, global variable to prevent inling
julia> f(a) = _isa(a, Function);
julia> g(@nospecialize a) = _isa(a, Function);
julia> dispatchonly = Any[sin, muladd, nothing]; # untyped container can cause excessive runtime dispatch

julia> @code_typed invokef(f, dispatchonly)
CodeInfo(
1 ─ %1  = π (0, Int64)
│   %2  = Base.arrayref(true, itr, 1)::Any
│   %3  = (f)(%2)::Any
│   %4  = (%1 + %3)::Any
│   %5  = Base.arrayref(true, itr, 2)::Any
│   %6  = (f)(%5)::Any
│   %7  = (%4 + %6)::Any
│   %8  = Base.arrayref(true, itr, 3)::Any
│   %9  = (f)(%8)::Any
│   %10 = (%7 + %9)::Any
└──       return %10
) => Any

julia> @code_typed invokef(g, dispatchonly)
CodeInfo(
1 ─ %1  = π (0, Int64)
│   %2  = Base.arrayref(true, itr, 1)::Any
│   %3  = invoke f(%2::Any)::Any
│   %4  = (%1 + %3)::Any
│   %5  = Base.arrayref(true, itr, 2)::Any
│   %6  = invoke f(%5::Any)::Any
│   %7  = (%4 + %6)::Any
│   %8  = Base.arrayref(true, itr, 3)::Any
│   %9  = invoke f(%8::Any)::Any
│   %10 = (%7 + %9)::Any
└──       return %10
) => Any
```

The calls of `f` remain to be `:call` expression (thus dispatched and
compiled at runtime) while the calls of `g` are resolved as `:invoke`
expressions. This is because `@nospecialize` requests the compiler to
give up compiling `g` with concrete argument types but with precisely
declared argument types, and in this way `invokef(g, dispatchonly)` will
avoid runtime dispatches and accompanying JIT compilations (i.e. "excess
code generation").

The problem here is, it influences dispatch only, does not intervene
into inference in anyway. So there is still a possibility of "excess
inference" when the compiler sees a considerable complexity of argument
types  during inference:
```julia
julia> withinfernce = tuple(sin, muladd, "foo"); # typed container can cause excessive inference

julia> @time @code_typed invokef(f, withinfernce);
  0.000812 seconds (3.77 k allocations: 217.938 KiB, 94.34% compilation time)

julia> @time @code_typed invokef(g, withinfernce);
  0.000753 seconds (3.77 k allocations: 218.047 KiB, 92.42% compilation time)
```

The purpose of this PR is basically to provide a more drastic way to
avoid excess compilation.

Here are some ideas to implement the functionality:
1. make `@nospecialize` avoid inference also
2. add noinfer effect when `@nospecialize`d method is annotated as `@noinline` also
3. implement as `@pure`-like boolean annotation to request noinfer effect on top of `@nospecialize`
4. implement as annotation that is orthogonal to `@nospecialize`

After trying 1 ~ 3., I decided to submit 3. for now, because I think the
interface is ready to be experimented.

This is almost same as what Jameson has done at <vtjnash@8ab7b6b>.
It turned out that this approach performs very badly because some of
`@nospecialize`'d arguments still need inference to perform reasonably.
For example, it's obvious that the following definition of
`getindex(@nospecialize(t::Tuple), i::Int)` would perform very badly if
`@nospecialize` blocks inference, because of a lack of useful type
information for succeeding optimizations:
<https://github.com/JuliaLang/julia/blob/12d364e8249a07097a233ce7ea2886002459cc50/base/tuple.jl#L29-L30>

The important observation is that we often use `@nospecialize` even when
we expect inference to forward type and constant information.
Adversely, we may be able to exploit the fact that we usually don't
expect inference to forward information to a callee when we annotate it
as `@noinline`.
So the idea is to enable the inference suppression when `@nospecialize`'d
method is annotated as `@noinline` also.

It's a reasonable choice, and could be implemented efficiently after <#41922>.
But it sounds a bit weird to me to associate no infer effect with
`@noinline`, and I also think there may be some cases we want to inline
a method while _partially_ avoiding inference, e.g.:
```julia
@noinline function twof(@nospecialize(f), n) # we really want not to
inline this method body ?
    if occursin('+', string(typeof(f).name.name::Symbol))
        2 + n
    elseif occursin('*', string(typeof(f).name.name::Symbol))
        2n
    else
        zero(n)
    end
end
```

So this is what this commit implements. It basically replaces the previous
`@noinline` flag with newly-introduced annotation named `@noinfer`. It's
still associated with `@nospecialize` and it only has effect when used
together with `@nospecialize`, but now it's not associated to `@noinline`
at least, and it would help us reason about the behavior of `@noinfer`
and experiment its effect more reliably:
```julia
Base.@noinfer function twof(@nospecialize(f), n) # the compiler may or not inline this method
    if occursin('+', string(typeof(f).name.name::Symbol))
        2 + n
    elseif occursin('*', string(typeof(f).name.name::Symbol))
        2n
    else
        zero(n)
    end
end
```

Actually, we can have `@nospecialize` and `@noinfer` separately, and it
would allow us to configure compilation strategies in a more
fine-grained way.
```julia
function noinfspec(Base.@noinfer(f), @nospecialize(g))
    ...
end
```

I'm fine with this approach, if initial experiments show `@noinfer` is
useful.

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>
Co-authored-by: Tim Holy <tim.holy@gmail.com>
  • Loading branch information
3 people committed Apr 12, 2023
1 parent b4cc5c2 commit f293f89
Show file tree
Hide file tree
Showing 15 changed files with 205 additions and 28 deletions.
10 changes: 9 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,10 @@ function abstract_call_method(interp::AbstractInterpreter,
sigtuple = unwrap_unionall(sig)
sigtuple isa DataType || return MethodCallResult(Any, false, false, nothing, Effects())

if is_noinfer(method)
sig = get_nospecialize_sig(method, sig, sparams)
end

# Limit argument type tuple growth of functions:
# look through the parents list to see if there's a call to the same method
# and from the same method.
Expand Down Expand Up @@ -1075,7 +1079,11 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
return nothing
end
force |= all_overridden
mi = specialize_method(match; preexisting=!force)
if is_noinfer(method)
mi = specialize_method_noinfer(match; preexisting=!force)
else
mi = specialize_method(match; preexisting=!force)
end
if mi === nothing
add_remark!(interp, sv, "[constprop] Failed to specialize")
return nothing
Expand Down
6 changes: 5 additions & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,11 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
end

# See if there exists a specialization for this method signature
mi = specialize_method(match; preexisting=true) # Union{Nothing, MethodInstance}
if is_noinfer(method)
mi = specialize_method_noinfer(match; preexisting=true)
else
mi = specialize_method(match; preexisting=true)
end
if mi === nothing
et = InliningEdgeTracker(state.et, invokesig)
effects = info_effects(nothing, match, state)
Expand Down
26 changes: 25 additions & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ function is_inlineable_constant(@nospecialize(x))
return count_const_size(x) <= MAX_INLINE_CONST_SIZE
end

is_nospecialized(method::Method) = method.nospecialize 0

is_noinfer(method::Method) = method.noinfer && is_nospecialized(method)
# is_noinfer(method::Method) = is_nospecialized(method) && is_declared_noinline(method)

###########################
# MethodInstance/CodeInfo #
###########################
Expand Down Expand Up @@ -158,6 +163,19 @@ function get_compileable_sig(method::Method, @nospecialize(atype), sparams::Simp
mt, atype, sparams, method)
end

function get_nospecialize_sig(method::Method, @nospecialize(atype), sparams::SimpleVector)
if isa(atype, UnionAll)
atype, sparams = normalize_typevars(method, atype, sparams)
end
isa(atype, DataType) || return method.sig
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
mt === nothing && return method.sig
# TODO allow uncompileable signatures to be returned here
sig = ccall(:jl_normalize_to_compilable_sig, Any, (Any, Any, Any, Any),
mt, atype, sparams, method)
return sig === nothing ? method.sig : sig
end

isa_compileable_sig(@nospecialize(atype), sparams::SimpleVector, method::Method) =
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any, Any), atype, sparams, method))

Expand Down Expand Up @@ -199,7 +217,8 @@ function normalize_typevars(method::Method, @nospecialize(atype), sparams::Simpl
end

# get a handle to the unique specialization object representing a particular instantiation of a call
function specialize_method(method::Method, @nospecialize(atype), sparams::SimpleVector; preexisting::Bool=false, compilesig::Bool=false)
function specialize_method(method::Method, @nospecialize(atype), sparams::SimpleVector;
preexisting::Bool=false, compilesig::Bool=false)
if isa(atype, UnionAll)
atype, sparams = normalize_typevars(method, atype, sparams)
end
Expand All @@ -225,6 +244,11 @@ function specialize_method(match::MethodMatch; kwargs...)
return specialize_method(match.method, match.spec_types, match.sparams; kwargs...)
end

function specialize_method_noinfer((; method, spec_types, sparams)::MethodMatch; kwargs...)
atype = get_nospecialize_sig(method, spec_types, sparams)
return specialize_method(method, atype, sparams; kwargs...)
end

"""
is_declared_inline(method::Method) -> Bool
Expand Down
3 changes: 2 additions & 1 deletion base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ f(y) = [x for x in y]
!!! note
`@nospecialize` affects code generation but not inference: it limits the diversity
of the resulting native code, but it does not impose any limitations (beyond the
standard ones) on type-inference.
standard ones) on type-inference. Use [`Base.@noinfer`](@ref) together with
`@nospecialize` to additionally suppress inference.
# Example
Expand Down
34 changes: 33 additions & 1 deletion base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,6 @@ macro noinline(x)
return annotate_meta_def_or_block(x, :noinline)
end


"""
@constprop setting [ex]
Expand Down Expand Up @@ -753,6 +752,39 @@ function compute_assumed_setting(@nospecialize(setting), val::Bool=true)
end
end

"""
Base.@noinfer function f(args...)
@nospecialize ...
...
end
Base.@noinfer f(@nospecialize args...) = ...
Tells the compiler to infer `f` using the declared types of `@nospecialize`d arguments.
This can be used to limit the number of compiler-generated specializations during inference.
# Example
```julia
julia> f(A::AbstractArray) = g(A)
f (generic function with 1 method)
julia> @noinline Base.@noinfer g(@nospecialize(A::AbstractArray)) = A[1]
g (generic function with 1 method)
julia> @code_typed f([1.0])
CodeInfo(
1 ─ %1 = invoke Main.g(_2::AbstractArray)::Any
└── return %1
) => Any
```
In this example, `f` will be inferred for each specific type of `A`,
but `g` will only be inferred once.
"""
macro noinfer(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :noinfer) : ex)
end

"""
@propagate_inbounds
Expand Down
2 changes: 2 additions & 0 deletions doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ Base.@inline
Base.@noinline
Base.@nospecialize
Base.@specialize
Base.@noinfer
Base.@constprop
Base.gensym
Base.@gensym
var"name"
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ JL_DLLEXPORT jl_sym_t *jl_aggressive_constprop_sym;
JL_DLLEXPORT jl_sym_t *jl_no_constprop_sym;
JL_DLLEXPORT jl_sym_t *jl_purity_sym;
JL_DLLEXPORT jl_sym_t *jl_nospecialize_sym;
JL_DLLEXPORT jl_sym_t *jl_noinfer_sym;
JL_DLLEXPORT jl_sym_t *jl_macrocall_sym;
JL_DLLEXPORT jl_sym_t *jl_colon_sym;
JL_DLLEXPORT jl_sym_t *jl_hygienicscope_sym;
Expand Down Expand Up @@ -342,6 +343,7 @@ void jl_init_common_symbols(void)
jl_isdefined_sym = jl_symbol("isdefined");
jl_nospecialize_sym = jl_symbol("nospecialize");
jl_specialize_sym = jl_symbol("specialize");
jl_noinfer_sym = jl_symbol("noinfer");
jl_optlevel_sym = jl_symbol("optlevel");
jl_compile_sym = jl_symbol("compile");
jl_force_compile_sym = jl_symbol("force_compile");
Expand Down
10 changes: 6 additions & 4 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,14 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
}
}

static jl_code_info_flags_t code_info_flags(uint8_t inferred, uint8_t propagate_inbounds,
uint8_t has_fcall, uint8_t inlining, uint8_t constprop)
static jl_code_info_flags_t code_info_flags(uint8_t inferred, uint8_t propagate_inbounds, uint8_t has_fcall,
uint8_t noinfer, uint8_t inlining, uint8_t constprop)
{
jl_code_info_flags_t flags;
flags.bits.inferred = inferred;
flags.bits.propagate_inbounds = propagate_inbounds;
flags.bits.has_fcall = has_fcall;
flags.bits.noinfer = noinfer;
flags.bits.inlining = inlining;
flags.bits.constprop = constprop;
return flags;
Expand Down Expand Up @@ -780,8 +781,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
1
};

jl_code_info_flags_t flags = code_info_flags(code->inferred, code->propagate_inbounds,
code->has_fcall, code->inlining, code->constprop);
jl_code_info_flags_t flags = code_info_flags(code->inferred, code->propagate_inbounds, code->has_fcall,
code->noinfer, code->inlining, code->constprop);
write_uint8(s.s, flags.packed);
write_uint8(s.s, code->purity.bits);
write_uint16(s.s, code->inlining_cost);
Expand Down Expand Up @@ -880,6 +881,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
code->inferred = flags.bits.inferred;
code->propagate_inbounds = flags.bits.propagate_inbounds;
code->has_fcall = flags.bits.has_fcall;
code->noinfer = flags.bits.noinfer;
code->purity.bits = read_uint8(s.s);
code->inlining_cost = read_uint16(s.s);

Expand Down
14 changes: 9 additions & 5 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2681,7 +2681,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(21,
jl_perm_symsvec(22,
"code",
"codelocs",
"ssavaluetypes",
Expand All @@ -2699,11 +2699,12 @@ void jl_init_types(void) JL_GC_DISABLED
"inferred",
"propagate_inbounds",
"has_fcall",
"noinfer",
"inlining",
"constprop",
"purity",
"inlining_cost"),
jl_svec(21,
jl_svec(22,
jl_array_any_type,
jl_array_int32_type,
jl_any_type,
Expand All @@ -2721,17 +2722,18 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint8_type,
jl_uint16_type),
jl_emptysvec,
0, 1, 20);
0, 1, 22);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(28,
jl_perm_symsvec(29,
"name",
"module",
"file",
Expand All @@ -2758,9 +2760,10 @@ void jl_init_types(void) JL_GC_DISABLED
"nkw",
"isva",
"is_for_opaque_closure",
"noinfer",
"constprop",
"purity"),
jl_svec(28,
jl_svec(29,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
Expand All @@ -2787,6 +2790,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_int32_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type),
jl_emptysvec,
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ typedef struct _jl_code_info_t {
uint8_t inferred;
uint8_t propagate_inbounds;
uint8_t has_fcall;
uint8_t noinfer;
// uint8 settings
uint8_t inlining; // 0 = default; 1 = @inline; 2 = @noinline
uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none
Expand Down Expand Up @@ -343,6 +344,7 @@ typedef struct _jl_method_t {
// various boolean properties
uint8_t isva;
uint8_t is_for_opaque_closure;
uint8_t noinfer;
// uint8 settings
uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none

Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ typedef struct {
uint8_t inferred:1;
uint8_t propagate_inbounds:1;
uint8_t has_fcall:1;
uint8_t noinfer:1;
uint8_t inlining:2; // 0 = use heuristic; 1 = aggressive; 2 = none
uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none
} jl_code_info_flags_bitfield_t;
Expand Down Expand Up @@ -1566,6 +1567,7 @@ extern JL_DLLEXPORT jl_sym_t *jl_aggressive_constprop_sym;
extern JL_DLLEXPORT jl_sym_t *jl_no_constprop_sym;
extern JL_DLLEXPORT jl_sym_t *jl_purity_sym;
extern JL_DLLEXPORT jl_sym_t *jl_nospecialize_sym;
extern JL_DLLEXPORT jl_sym_t *jl_noinfer_sym;
extern JL_DLLEXPORT jl_sym_t *jl_macrocall_sym;
extern JL_DLLEXPORT jl_sym_t *jl_colon_sym;
extern JL_DLLEXPORT jl_sym_t *jl_hygienicscope_sym;
Expand Down
5 changes: 5 additions & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
li->inlining = 2;
else if (ma == (jl_value_t*)jl_propagate_inbounds_sym)
li->propagate_inbounds = 1;
else if (ma == (jl_value_t*)jl_noinfer_sym)
li->noinfer = 1;
else if (ma == (jl_value_t*)jl_aggressive_constprop_sym)
li->constprop = 1;
else if (ma == (jl_value_t*)jl_no_constprop_sym)
Expand Down Expand Up @@ -476,6 +478,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->inferred = 0;
src->propagate_inbounds = 0;
src->has_fcall = 0;
src->noinfer = 0;
src->edges = jl_nothing;
src->constprop = 0;
src->inlining = 0;
Expand Down Expand Up @@ -679,6 +682,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
}
}
m->called = called;
m->noinfer = src->noinfer;
m->constprop = src->constprop;
m->purity.bits = src->purity.bits;
jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name);
Expand Down Expand Up @@ -808,6 +812,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
m->primary_world = 1;
m->deleted_world = ~(size_t)0;
m->is_for_opaque_closure = 0;
m->noinfer = 0;
m->constprop = 0;
JL_MUTEX_INIT(&m->writelock);
return m;
Expand Down
11 changes: 10 additions & 1 deletion stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ const TAGS = Any[
const NTAGS = length(TAGS)
@assert NTAGS == 255

const ser_version = 23 # do not make changes without bumping the version #!
const ser_version = 24 # do not make changes without bumping the version #!

format_version(::AbstractSerializer) = ser_version
format_version(s::Serializer) = s.version
Expand Down Expand Up @@ -418,6 +418,7 @@ function serialize(s::AbstractSerializer, meth::Method)
serialize(s, meth.nargs)
serialize(s, meth.isva)
serialize(s, meth.is_for_opaque_closure)
serialize(s, meth.noinfer)
serialize(s, meth.constprop)
serialize(s, meth.purity)
if isdefined(meth, :source)
Expand Down Expand Up @@ -1026,10 +1027,14 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
nargs = deserialize(s)::Int32
isva = deserialize(s)::Bool
is_for_opaque_closure = false
noinfer = false
constprop = purity = 0x00
template_or_is_opaque = deserialize(s)
if isa(template_or_is_opaque, Bool)
is_for_opaque_closure = template_or_is_opaque
if format_version(s) >= 24
noinfer = deserialize(s)::Bool
end
if format_version(s) >= 14
constprop = deserialize(s)::UInt8
end
Expand All @@ -1054,6 +1059,7 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
meth.nargs = nargs
meth.isva = isva
meth.is_for_opaque_closure = is_for_opaque_closure
meth.noinfer = noinfer
meth.constprop = constprop
meth.purity = purity
if template !== nothing
Expand Down Expand Up @@ -1195,6 +1201,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
if format_version(s) >= 20
ci.has_fcall = deserialize(s)
end
if format_version(s) >= 24
ci.noinfer = deserialize(s)::Bool
end
if format_version(s) >= 21
ci.inlining = deserialize(s)::UInt8
end
Expand Down
Loading

0 comments on commit f293f89

Please sign in to comment.