Skip to content

use src.nargs for validate_code! #58327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Compiler/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstanc
mnargs = 0
else
m = mi.def::Method
mnargs = m.nargs
mnargs = Int(m.nargs)
n_sig_params = length((unwrap_unionall(m.sig)::DataType).parameters)
if m.is_for_opaque_closure
m.sig === Tuple || push!(errors, InvalidCodeError(INVALID_SIGNATURE_OPAQUE_CLOSURE, (m.sig, m.isva)))
Expand All @@ -234,6 +234,7 @@ function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstanc
end
end
if isa(c, CodeInfo)
mnargs = Int(c.nargs)
mnargs > length(c.slotnames) && push!(errors, InvalidCodeError(SLOTNAMES_NARGS_MISMATCH))
validate_code!(errors, c, is_top_level)
end
Expand Down
53 changes: 39 additions & 14 deletions Compiler/test/contextual.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module contextual

# N.B.: This file is also run from interpreter.jl, so needs to be standalone-executable
using Test

include("setup_Compiler.jl")

# Cassette
# ========

# TODO Use CassetteBase.jl instead of this mini-cassette?

module MiniCassette
# A minimal demonstration of the cassette mechanism. Doesn't support all the
# fancy features, but sufficient to exercise this code path in the compiler.

using Core: SimpleVector
using Core.IR
using ..Compiler
using ..Compiler: retrieve_code_info, quoted, anymap
using Base: Compiler as CC
using .CC: retrieve_code_info, quoted, anymap
using Base.Meta: isexpr

export Ctx, overdub

struct Ctx; end

# A no-op cassette-like transform
function transform_expr(expr, map_slot_number, map_ssa_value, sparams::Core.SimpleVector)
function transform_expr(expr, map_slot_number, map_ssa_value, sparams::SimpleVector)
@nospecialize expr
transform(@nospecialize expr) = transform_expr(expr, map_slot_number, map_ssa_value, sparams)
if isexpr(expr, :call)
Expand All @@ -46,11 +49,11 @@ module MiniCassette
end
end

function transform!(mi::MethodInstance, ci::CodeInfo, nargs::Int, sparams::Core.SimpleVector)
function transform!(mi::MethodInstance, ci::CodeInfo, nargs::Int, sparams::SimpleVector)
code = ci.code
di = Compiler.DebugInfoStream(mi, ci.debuginfo, length(code))
ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+1:end]...]
ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+1:end]...]
di = CC.DebugInfoStream(mi, ci.debuginfo, length(code))
ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+2:end]...]
ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+2:end]...]
# Insert one SSAValue for every argument statement
prepend!(code, Any[Expr(:call, getfield, SlotNumber(4), i) for i = 1:nargs])
prepend!(di.codelocs, fill(Int32(0), 3nargs))
Expand All @@ -77,31 +80,48 @@ module MiniCassette

function overdub_generator(world::UInt, source, self, ctx, f, args)
@nospecialize
argnames = Core.svec(:overdub, :ctx, :f, :args)
spnames = Core.svec()

if !Base.issingletontype(f)
# (c, f, args..) -> f(args...)
ex = :(return f(args...))
return Core.GeneratedFunctionStub(identity, Core.svec(:overdub, :ctx, :f, :args), Core.svec())(world, source, ex)
return generate_lambda_ex(world, source, argnames, spnames, :(return f(args...)))
end

tt = Tuple{f, args...}
match = Base._which(tt; world)
mi = Base.specialize_method(match)
# Unsupported in this mini-cassette
@assert !mi.def.isva
!mi.def.isva ||
return generate_lambda_ex(world, source, argnames, spnames, :(error("Unsupported vararg method")))
src = retrieve_code_info(mi, world)
@assert isa(src, CodeInfo)
isa(src, CodeInfo) ||
return generate_lambda_ex(world, source, argnames, spnames, :(error("Unexpected code transformation")))
src = copy(src)
@assert src.edges === Core.svec()
src.edges === Core.svec() ||
return generate_lambda_ex(world, source, argnames, spnames, :(error("Unexpected code transformation")))
src.edges = Any[mi]
transform!(mi, src, length(args), match.sparams)
# TODO: this is mandatory: code_info.min_world = max(code_info.min_world, min_world[])
# TODO: this is mandatory: code_info.max_world = min(code_info.max_world, max_world[])
# Match the generator, since that's what our transform! does
src.nargs = 4
src.isva = true
errors = CC.validate_code(mi, src)
if !isempty(errors)
foreach(Core.println, errors)
return generate_lambda_ex(world, source, argnames, spnames, :(error("Found errors in generated code")))
end
return src
end

function generate_lambda_ex(world::UInt, source::Method,
argnames::SimpleVector, spnames::SimpleVector,
body::Expr)
stub = Core.GeneratedFunctionStub(identity, argnames, spnames)
return stub(world, source, body)
end

@inline overdub(::Ctx, f::Union{Core.Builtin, Core.IntrinsicFunction}, args...) = f(args...)

@eval function overdub(ctx::Ctx, f, args...)
Expand All @@ -125,3 +145,8 @@ f() = 2
foo(i) = i+bar(Val(1))

@test @inferred(overdub(Ctx(), foo, 1)) == 43

morethan4args(a, b, c, d, e) = (((a + b) + c) + d) + e
@test overdub(Ctx(), morethan4args, 1, 2, 3, 4, 5) == 15

end # module contextual
34 changes: 9 additions & 25 deletions Compiler/test/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ function f22938(a, b, x...)
return i * a
end

msig = Tuple{typeof(f22938),Int,Int,Int,Int}
world = Base.get_world_counter()
match = only(Base._methods_by_ftype(msig, -1, world))
mi = Compiler.specialize_method(match)
c0 = Compiler.retrieve_code_info(mi, world)

@test isempty(Compiler.validate_code(mi, c0))
const c0 = let
msig = Tuple{typeof(f22938),Int,Int,Int,Int}
world = Base.get_world_counter()
match = only(Base._methods_by_ftype(msig, -1, world))
mi = Compiler.specialize_method(match)
c0 = Compiler.retrieve_code_info(mi, world)
@test isempty(Compiler.validate_code(mi, c0))
c0
end

@testset "INVALID_EXPR_HEAD" begin
c = copy(c0)
Expand Down Expand Up @@ -114,28 +116,10 @@ end
@test errors[1].kind === Compiler.SSAFLAGS_MISMATCH
end

@testset "SIGNATURE_NARGS_MISMATCH" begin
old_sig = mi.def.sig
mi.def.sig = Tuple{1,2}
errors = Compiler.validate_code(mi, nothing)
mi.def.sig = old_sig
@test length(errors) == 1
@test errors[1].kind === Compiler.SIGNATURE_NARGS_MISMATCH
end

@testset "NON_TOP_LEVEL_METHOD" begin
c = copy(c0)
c.code[1] = Expr(:method, :dummy)
errors = Compiler.validate_code(c)
@test length(errors) == 1
@test errors[1].kind === Compiler.NON_TOP_LEVEL_METHOD
end

@testset "SLOTNAMES_NARGS_MISMATCH" begin
mi.def.nargs += 20
errors = Compiler.validate_code(mi, c0)
mi.def.nargs -= 20
@test length(errors) == 2
@test count(e.kind === Compiler.SLOTNAMES_NARGS_MISMATCH for e in errors) == 1
@test count(e.kind === Compiler.SIGNATURE_NARGS_MISMATCH for e in errors) == 1
end