Skip to content

Commit 52bafeb

Browse files
committed
Extend PartialTuple for structs
This aims to address the following issue: Say we have: ``` function foo(b) a = 1 f = ()->Val(a) f() end ``` This infers beatifully, because the closure struct gets `Const`'ed and thus inter-procedural constant prop takes care of it. However, if we instead do ``` function foo(b) a = 1 f = ()->(println(b); Val(a)) f() end ``` the best inference can tell us about this is that we get `Val`, because the captured arguments are no longer constant. This leads to significant inference problems for Zygote, because the backwards pass is always specified as a closure. Thus, even if constant information is present in the forward pass, it is often lost for the part of the function that's the backwards pass, because it has to pass through the closure struct. This fixes that by keeping track of field types individually.
1 parent e5e1253 commit 52bafeb

File tree

8 files changed

+115
-33
lines changed

8 files changed

+115
-33
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,21 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
149149
return rettype
150150
end
151151

152+
153+
function const_prop_profitable(arg)
154+
# have new information from argtypes that wasn't available from the signature
155+
if isa(arg, PartialStruct)
156+
for b in arg.fields
157+
isconstType(b) && return true
158+
const_prop_profitable(b) && return true
159+
end
160+
elseif !isa(arg, Const) || (isa(arg.val, Symbol) || isa(arg.val, Type) || (!isa(arg.val, String) && isimmutable(arg.val)))
161+
# don't consider mutable values or Strings useful constants
162+
return true
163+
end
164+
return false
165+
end
166+
152167
function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::SimpleVector, sv::InferenceState)
153168
method = match[3]::Method
154169
nargs::Int = method.nargs
@@ -158,12 +173,8 @@ function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecial
158173
for a in argtypes
159174
a = widenconditional(a)
160175
if has_nontrivial_const_info(a)
161-
# have new information from argtypes that wasn't available from the signature
162-
if !isa(a, Const) || (isa(a.val, Symbol) || isa(a.val, Type) || (!isa(a.val, String) && isimmutable(a.val)))
163-
# don't consider mutable values or Strings useful constants
164-
haveconst = true
165-
break
166-
end
176+
haveconst = const_prop_profitable(a)
177+
haveconst && break
167178
end
168179
end
169180
haveconst || improvable_via_constant_propagation(rettype) || return Any
@@ -189,7 +200,7 @@ function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecial
189200
# in this case, see if all of the arguments are constants
190201
for a in argtypes
191202
a = widenconditional(a)
192-
if !isa(a, Const) && !isconstType(a)
203+
if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct)
193204
return Any
194205
end
195206
end
@@ -384,7 +395,7 @@ end
384395
# Union of Tuples of the same length is converted to Tuple of Unions.
385396
# returns an array of types
386397
function precise_container_type(@nospecialize(typ), vtypes::VarTable, sv::InferenceState)
387-
if isa(typ, PartialTuple)
398+
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
388399
return typ.fields
389400
end
390401

@@ -498,8 +509,9 @@ end
498509
# do apply(af, fargs...), where af is a function value
499510
function abstract_apply(@nospecialize(aft), aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState,
500511
max_methods = sv.params.MAX_METHODS)
501-
if !isa(aft, Const) && (!isType(aft) || has_free_typevars(aft))
502-
if !isconcretetype(aft) || (aft <: Builtin)
512+
aftw = widenconst(aft)
513+
if !isa(aft, Const) && (!isType(aftw) || has_free_typevars(aftw))
514+
if !isconcretetype(aftw) || (aftw <: Builtin)
503515
# non-constant function of unknown type: bail now,
504516
# since it seems unlikely that abstract_call will be able to do any better after splitting
505517
# this also ensures we don't call abstract_call_gf_by_type below on an IntrinsicFunction or Builtin
@@ -891,26 +903,37 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
891903
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
892904
if isconcretetype(t) && !t.mutable
893905
args = Vector{Any}(undef, length(e.args)-1)
894-
isconst = true
906+
ats = Vector{Any}(undef, length(e.args)-1)
907+
anyconst = false
908+
allconst = true
895909
for i = 2:length(e.args)
896910
at = abstract_eval(e.args[i], vtypes, sv)
911+
if !anyconst
912+
anyconst = has_nontrivial_const_info(at)
913+
end
914+
ats[i-1] = at
897915
if at === Bottom
898916
t = Bottom
899-
isconst = false
917+
allconst = anyconst = false
900918
break
901919
elseif at isa Const
902920
if !(at.val isa fieldtype(t, i - 1))
903921
t = Bottom
904-
isconst = false
922+
allconst = anyconst = false
905923
break
906924
end
907925
args[i-1] = at.val
908926
else
909-
isconst = false
927+
allconst = false
910928
end
911929
end
912-
if isconst
913-
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
930+
# For now, don't allow partially initialized Const/PartialStruct
931+
if t !== Bottom && fieldcount(t) == length(ats)
932+
if allconst
933+
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
934+
elseif anyconst
935+
t = PartialStruct(t, ats)
936+
end
914937
end
915938
end
916939
elseif e.head === :splatnew
@@ -1077,7 +1100,7 @@ function typeinf_local(frame::InferenceState)
10771100
elseif hd === :return
10781101
pc´ = n + 1
10791102
rt = widenconditional(abstract_eval(stmt.args[1], s[pc], frame))
1080-
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialTuple)
1103+
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
10811104
# only propagate information we know we can store
10821105
# and is valid inter-procedurally
10831106
rt = widenconst(rt)

base/compiler/inferenceresult.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
function is_argtype_match(@nospecialize(given_argtype),
2323
@nospecialize(cache_argtype),
2424
overridden_by_const::Bool)
25-
if isa(given_argtype, Const) || isa(given_argtype, PartialTuple)
25+
if isa(given_argtype, Const) || isa(given_argtype, PartialStruct)
2626
return is_lattice_equal(given_argtype, cache_argtype)
2727
end
2828
return !overridden_by_const
@@ -66,7 +66,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
6666
nargs::Int = toplevel ? 0 : linfo.def.nargs
6767
cache_argtypes = Vector{Any}(undef, nargs)
6868
# First, if we're dealing with a varargs method, then we set the last element of `args`
69-
# to the appropriate `Tuple` type or `PartialTuple` instance.
69+
# to the appropriate `Tuple` type or `PartialStruct` instance.
7070
if !toplevel && linfo.def.isva
7171
if linfo.specTypes == Tuple
7272
if nargs > 1

base/compiler/ssair/inlining.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,8 @@ function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, at
583583
for i in 3:length(argexprs)
584584
def = argexprs[i]
585585
def_type = atypes[i]
586-
if def_type isa PartialTuple
586+
if def_type isa PartialStruct
587+
# def_type.typ <: Tuple is assumed
587588
def_atypes = def_type.fields
588589
else
589590
def_atypes = Any[]

base/compiler/tfuncs.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,9 +716,12 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
716716
end
717717
end
718718
s = typeof(sv)
719-
elseif isa(s, PartialTuple)
719+
elseif isa(s, PartialStruct)
720720
if isa(name, Const)
721721
nv = name.val
722+
if isa(nv, Symbol)
723+
nv = fieldindex(widenconst(s), nv, false)
724+
end
722725
if isa(nv, Int) && 1 <= nv <= length(s.fields)
723726
return s.fields[nv]
724727
end
@@ -1139,7 +1142,7 @@ function tuple_tfunc(atypes::Vector{Any})
11391142
typ = Tuple{params...}
11401143
# replace a singleton type with its equivalent Const object
11411144
isdefined(typ, :instance) && return Const(typ.instance)
1142-
return anyinfo ? PartialTuple(typ, atypes) : typ
1145+
return anyinfo ? PartialStruct(typ, atypes) : typ
11431146
end
11441147

11451148
function array_type_undefable(@nospecialize(a))

base/compiler/typelattice.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct StateUpdate
7070
state::VarTable
7171
end
7272

73-
struct PartialTuple
73+
struct PartialStruct
7474
typ
7575
fields::Vector{Any} # elements are other type lattice members
7676
end
@@ -125,8 +125,8 @@ function ⊑(@nospecialize(a), @nospecialize(b))
125125
elseif isa(b, Conditional)
126126
return false
127127
end
128-
if isa(a, PartialTuple)
129-
if isa(b, PartialTuple)
128+
if isa(a, PartialStruct)
129+
if isa(b, PartialStruct)
130130
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
131131
return false
132132
end
@@ -137,9 +137,15 @@ function ⊑(@nospecialize(a), @nospecialize(b))
137137
return true
138138
end
139139
return isa(b, Type) && a.typ <: b
140-
elseif isa(b, PartialTuple)
140+
elseif isa(b, PartialStruct)
141141
if isa(a, Const)
142142
nfields(a.val) == length(b.fields) || return false
143+
widenconst(b).name === widenconst(a).name || return false
144+
# We can skip the subtype check if b is a Tuple, since in that
145+
# case, the ⊑ of the elements is sufficient.
146+
if b.typ.name !== Tuple.name && !(widenconst(a) <: widenconst(b))
147+
return false
148+
end
143149
for i in 1:nfields(a.val)
144150
# XXX: let's handle varargs later
145151
(Const(getfield(a.val, i)), b.fields[i]) || return false
@@ -173,15 +179,16 @@ end
173179
# `a ⊑ b && b ⊑ a` but with extra performance optimizations.
174180
function is_lattice_equal(@nospecialize(a), @nospecialize(b))
175181
a === b && return true
176-
if isa(a, PartialTuple)
177-
isa(b, PartialTuple) || return false
182+
if isa(a, PartialStruct)
183+
isa(b, PartialStruct) || return false
178184
length(a.fields) == length(b.fields) || return false
185+
widenconst(a) == widenconst(b) || return false
179186
for i in 1:length(a.fields)
180187
is_lattice_equal(a.fields[i], b.fields[i]) || return false
181188
end
182189
return true
183190
end
184-
isa(b, PartialTuple) && return false
191+
isa(b, PartialStruct) && return false
185192
a isa Const && return false
186193
b isa Const && return false
187194
return a b && b a
@@ -200,7 +207,7 @@ function widenconst(c::Const)
200207
end
201208
widenconst(m::MaybeUndef) = widenconst(m.typ)
202209
widenconst(c::PartialTypeVar) = TypeVar
203-
widenconst(t::PartialTuple) = t.typ
210+
widenconst(t::PartialStruct) = t.typ
204211
widenconst(@nospecialize(t)) = t
205212

206213
issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)

base/compiler/typelimits.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,27 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
317317
end
318318
return Bool
319319
end
320+
if (isa(typea, PartialStruct) || isa(typea, Const)) &&
321+
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
322+
widenconst(typea) === widenconst(typeb)
323+
324+
typea_nfields = nfields_tfunc(typea)
325+
typeb_nfields = nfields_tfunc(typeb)
326+
if !isa(typea_nfields, Const) || !isa(typea_nfields, Const) || typea_nfields.val !== typeb_nfields.val
327+
return widenconst(typea)
328+
end
329+
330+
type_nfields = typea_nfields.val::Int
331+
fields = Vector{Any}(undef, type_nfields)
332+
anyconst = false
333+
for i = 1:type_nfields
334+
fields[i] = tmerge(getfield_tfunc(typea, Const(i)),
335+
getfield_tfunc(typeb, Const(i)))
336+
anyconst |= has_nontrivial_const_info(fields[i])
337+
end
338+
return anyconst ? PartialStruct(widenconst(typea), fields) :
339+
widenconst(typea)
340+
end
320341
# no special type-inference lattice, join the types
321342
typea, typeb = widenconst(typea), widenconst(typeb)
322343
typea === typeb && return typea

base/compiler/typeutils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function issingletontype(@nospecialize t)
3333
end
3434

3535
function has_nontrivial_const_info(@nospecialize t)
36-
isa(t, PartialTuple) && return true
36+
isa(t, PartialStruct) && return true
3737
return isa(t, Const) && !isdefined(typeof(t.val), :instance) && !(isa(t.val, Type) && issingletontype(t.val))
3838
end
3939

test/compiler/inference.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,8 +1338,8 @@ let egal_tfunc
13381338
@test egal_tfunc(Union{Int64, Float64}, AbstractArray) === Const(false)
13391339
end
13401340

1341-
using Core.Compiler: PartialTuple, nfields_tfunc, sizeof_tfunc, sizeof_nothrow
1342-
let PT = PartialTuple(Tuple{Int64,UInt64}, Any[Const(10, false), UInt64])
1341+
using Core.Compiler: PartialStruct, nfields_tfunc, sizeof_tfunc, sizeof_nothrow
1342+
let PT = PartialStruct(Tuple{Int64,UInt64}, Any[Const(10, false), UInt64])
13431343
@test sizeof_tfunc(PT) === Const(16, false)
13441344
@test nfields_tfunc(PT) === Const(2, false)
13451345
@test sizeof_nothrow(PT) === true
@@ -2261,3 +2261,30 @@ f_incr(x::Tuple, y::Tuple, args...) = f_incr((x, y), args...)
22612261
f_incr(x::Tuple) = x
22622262
@test @inferred(f_incr((), (), (), (), (), (), (), ())) ==
22632263
((((((((), ()), ()), ()), ()), ()), ()), ())
2264+
2265+
# Test PartialStruct for closures
2266+
@noinline use30783(x) = nothing
2267+
function foo30783(b)
2268+
a = 1
2269+
f = ()->(use30783(b); Val(a))
2270+
f()
2271+
end
2272+
@test @inferred(foo30783(2)) == Val(1)
2273+
2274+
# PartialStruct tmerge
2275+
using Core.Compiler: PartialStruct, tmerge, Const,
2276+
struct FooPartial
2277+
a::Int
2278+
b::Int
2279+
c::Int
2280+
end
2281+
let PT1 = PartialStruct(FooPartial, Any[Const(1), Const(2), Int]),
2282+
PT2 = PartialStruct(FooPartial, Any[Const(1), Int, Int]),
2283+
PT3 = PartialStruct(FooPartial, Any[Const(1), Int, Const(3)])
2284+
2285+
@test PT1 PT2
2286+
@test !(PT1 PT3) && !(PT2 PT1)
2287+
let (==) = (a, b)->(a b && b a)
2288+
@test tmerge(PT1, PT3) == PT2
2289+
end
2290+
end

0 commit comments

Comments
 (0)