@@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int,
241
241
push! (from_bbs, length (state. new_cfg_blocks))
242
242
# TODO : Right now we unconditionally generate a fallback block
243
243
# in case of subtyping errors - This is probably unnecessary.
244
- if i != length (cases) || (! fully_covered || ! params. trust_inference)
244
+ if i != length (cases) || (! fully_covered || ( ! params. trust_inference && isdispatchtuple (cases[i] . sig)) )
245
245
# This block will have the next condition or the final else case
246
246
push! (state. new_cfg_blocks, BasicBlock (StmtRange (idx, idx)))
247
247
push! (state. new_cfg_blocks[cond_bb]. succs, length (state. new_cfg_blocks))
@@ -481,7 +481,8 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
481
481
cond = true
482
482
aparams, mparams = atype. parameters:: SimpleVector , metharg. parameters:: SimpleVector
483
483
@assert length (aparams) == length (mparams)
484
- if i != length (cases) || ! fully_covered || ! params. trust_inference
484
+ if i != length (cases) || ! fully_covered ||
485
+ (! params. trust_inference && isdispatchtuple (cases[i]. sig))
485
486
for i in 1 : length (aparams)
486
487
a, m = aparams[i], mparams[i]
487
488
# If this is always true, we don't need to check for it
@@ -538,7 +539,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
538
539
bb += 1
539
540
# We're now in the fall through block, decide what to do
540
541
if fully_covered
541
- if ! params. trust_inference
542
+ if ! params. trust_inference && isdispatchtuple (cases[ end ] . sig)
542
543
e = Expr (:call , GlobalRef (Core, :throw ), FATAL_TYPE_BOUND_ERROR)
543
544
insert_node_here! (compact, NewInstruction (e, Union{}, line))
544
545
insert_node_here! (compact, NewInstruction (ReturnNode (), Union{}, line))
@@ -1170,7 +1171,10 @@ function analyze_single_call!(
1170
1171
cases = InliningCase[]
1171
1172
local only_method = nothing # keep track of whether there is one matching method
1172
1173
local meth:: MethodLookupResult
1173
- local fully_covered = true
1174
+ local handled_all_cases = true
1175
+ local any_covers_full = false
1176
+ local revisit_idx = nothing
1177
+
1174
1178
for i in 1 : length (infos)
1175
1179
meth = infos[i]. results
1176
1180
if meth. ambig
@@ -1179,7 +1183,7 @@ function analyze_single_call!(
1179
1183
return nothing
1180
1184
elseif length (meth) == 0
1181
1185
# No applicable methods; try next union split
1182
- fully_covered = false
1186
+ handled_all_cases = false
1183
1187
continue
1184
1188
else
1185
1189
if length (meth) == 1 && only_method != = false
@@ -1192,12 +1196,38 @@ function analyze_single_call!(
1192
1196
only_method = false
1193
1197
end
1194
1198
end
1195
- for match in meth
1196
- fully_covered &= handle_match! (match, argtypes, flag, state, cases)
1197
- fully_covered &= match. fully_covers
1199
+ for (j, match) in enumerate (meth)
1200
+ any_covers_full |= match. fully_covers
1201
+ if ! isdispatchtuple (match. spec_types)
1202
+ if ! match. fully_covers
1203
+ handled_all_cases = false
1204
+ continue
1205
+ end
1206
+ if revisit_idx === nothing
1207
+ revisit_idx = (i, j)
1208
+ else
1209
+ handled_all_cases = false
1210
+ revisit_idx = nothing
1211
+ end
1212
+ else
1213
+ handled_all_cases &= handle_match! (match, argtypes, flag, state, cases)
1214
+ end
1198
1215
end
1199
1216
end
1200
1217
1218
+ # If there's only one case that's not a dispatchtuple, we can
1219
+ # still unionsplit by visiting all the other cases first.
1220
+ # This is useful for code like:
1221
+ # foo(x::Int) = 1
1222
+ # foo(@nospecialize(x::Any)) = 2
1223
+ # where we where only a small number of specific dispatchable
1224
+ # cases are split off from an ::Any typed fallback.
1225
+ if handled_all_cases && revisit_idx != = nothing
1226
+ (i, j) = revisit_idx
1227
+ match = infos[i]. results[j]
1228
+ handled_all_cases &= handle_match! (match, argtypes, flag, state, cases)
1229
+ end
1230
+
1201
1231
# if the signature is fully covered and there is only one applicable method,
1202
1232
# we can try to inline it even if the signature is not a dispatch tuple
1203
1233
atype = argtypes_to_type (argtypes)
@@ -1213,10 +1243,10 @@ function analyze_single_call!(
1213
1243
item = analyze_method! (match, argtypes, flag, state)
1214
1244
item === nothing && return nothing
1215
1245
push! (cases, InliningCase (match. spec_types, item))
1216
- fully_covered = match. fully_covers
1246
+ any_covers_full = handled_all_cases = match. fully_covers
1217
1247
end
1218
1248
1219
- handle_cases! (ir, idx, stmt, atype, cases, fully_covered , todo, state. params)
1249
+ handle_cases! (ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases , todo, state. params)
1220
1250
end
1221
1251
1222
1252
# similar to `analyze_single_call!`, but with constant results
@@ -1227,7 +1257,8 @@ function handle_const_call!(
1227
1257
(; call, results) = cinfo
1228
1258
infos = isa (call, MethodMatchInfo) ? MethodMatchInfo[call] : call. matches
1229
1259
cases = InliningCase[]
1230
- local fully_covered = true
1260
+ local handled_all_cases = true
1261
+ local any_covers_full = false
1231
1262
local j = 0
1232
1263
for i in 1 : length (infos)
1233
1264
meth = infos[i]. results
@@ -1237,22 +1268,22 @@ function handle_const_call!(
1237
1268
return nothing
1238
1269
elseif length (meth) == 0
1239
1270
# No applicable methods; try next union split
1240
- fully_covered = false
1271
+ handled_all_cases = false
1241
1272
continue
1242
1273
end
1243
1274
for match in meth
1244
1275
j += 1
1245
1276
result = results[j]
1277
+ any_covers_full |= match. fully_covers
1246
1278
if isa (result, ConstResult)
1247
1279
case = const_result_item (result, state)
1248
1280
push! (cases, InliningCase (result. mi. specTypes, case))
1249
1281
elseif isa (result, InferenceResult)
1250
- fully_covered &= handle_inf_result! (result, argtypes, flag, state, cases)
1282
+ handled_all_cases &= handle_inf_result! (result, argtypes, flag, state, cases)
1251
1283
else
1252
1284
@assert result === nothing
1253
- fully_covered &= handle_match! (match, argtypes, flag, state, cases)
1285
+ handled_all_cases &= isdispatchtuple (match . spec_types) && handle_match! (match, argtypes, flag, state, cases)
1254
1286
end
1255
- fully_covered &= match. fully_covers
1256
1287
end
1257
1288
end
1258
1289
@@ -1265,17 +1296,16 @@ function handle_const_call!(
1265
1296
validate_sparams (mi. sparam_vals) || return nothing
1266
1297
item === nothing && return nothing
1267
1298
push! (cases, InliningCase (mi. specTypes, item))
1268
- fully_covered = atype <: mi.specTypes
1299
+ any_covers_full = handled_all_cases = atype <: mi.specTypes
1269
1300
end
1270
1301
1271
- handle_cases! (ir, idx, stmt, atype, cases, fully_covered , todo, state. params)
1302
+ handle_cases! (ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases , todo, state. params)
1272
1303
end
1273
1304
1274
1305
function handle_match! (
1275
1306
match:: MethodMatch , argtypes:: Vector{Any} , flag:: UInt8 , state:: InliningState ,
1276
1307
cases:: Vector{InliningCase} )
1277
1308
spec_types = match. spec_types
1278
- isdispatchtuple (spec_types) || return false
1279
1309
item = analyze_method! (match, argtypes, flag, state)
1280
1310
item === nothing && return false
1281
1311
_any (case-> case. sig === spec_types, cases) && return true
0 commit comments