From 4779826657f5654ec323961a7c2d79c4f94a39a5 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Sun, 9 Jan 2022 14:38:57 +0900 Subject: [PATCH] inference: follow up #43603, better `getfield_tfunc` impl (#43713) --- base/compiler/tfuncs.jl | 26 +++++++++++++------------- test/compiler/inference.jl | 1 + 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 04d698fe89ec0f..65a86f262d9b39 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -795,17 +795,13 @@ function getfield_tfunc(s00, name, order, boundscheck) end getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(s00, name, false) function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool) - s = unwrap_unionall(s00) - if isa(s, Union) - return tmerge(getfield_tfunc(rewrap_unionall(s.a, s00), name), - getfield_tfunc(rewrap_unionall(s.b, s00), name)) - elseif isa(s, Conditional) + if isa(s00, Conditional) return Bottom # Bool has no fields - elseif isa(s, Const) || isconstType(s) - if !isa(s, Const) - sv = s.parameters[1] + elseif isa(s00, Const) || isconstType(s00) + if !isa(s00, Const) + sv = s00.parameters[1] else - sv = s.val + sv = s00.val end if isa(name, Const) nv = name.val @@ -845,11 +841,15 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool return unwrapva(s00.fields[nv]) end end + else + s = unwrap_unionall(s00) end - if isType(s) || !isa(s, DataType) || isabstracttype(s) - return Any + if isa(s, Union) + return tmerge(_getfield_tfunc(rewrap_unionall(s.a, s00), name, setfield), + _getfield_tfunc(rewrap_unionall(s.b, s00), name, setfield)) end - s = s::DataType + isa(s, DataType) || return Any + isabstracttype(s) && return Any if s <: Tuple && !(Int <: widenconst(name)) return Bottom end @@ -873,7 +873,7 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool if !(_ts <: Tuple) return Any end - return getfield_tfunc(_ts, name) + return _getfield_tfunc(_ts, name, setfield) end ftypes = datatype_fieldtypes(s) nf = length(ftypes) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 35d4e87c736fdb..a42ffbbe84db03 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1728,6 +1728,7 @@ end @test setfield!_tfunc(Const(@__MODULE__), Const(:v), Int) === Union{} @test setfield!_tfunc(Const(@__MODULE__), Int, Int) === Union{} @test setfield!_tfunc(Module, Const(:v), Int) === Union{} +@test setfield!_tfunc(Union{Module,Base.RefValue{Any}}, Const(:v), Int) === Union{} @test setfield!_tfunc(ABCDconst, Const(:a), Any) === Union{} @test setfield!_tfunc(ABCDconst, Const(:b), Any) === Union{} @test setfield!_tfunc(ABCDconst, Const(:d), Any) === Union{}