Skip to content

Commit

Permalink
throw runtime TypeError on invalid ccall symbol (JuliaLang#49142)
Browse files Browse the repository at this point in the history
Throw runtime `TypeError` on invalid ccall or cglobal symbol, rather
than throwing an internal compilation error.

Closes JuliaLang#49141
Closes JuliaLang#45187

Co-authored-by: Jameson Nash <vtjnash@gmail.com>
  • Loading branch information
Pangoraw and vtjnash authored Mar 29, 2023
1 parent bc33c81 commit 8f78a94
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 22 deletions.
42 changes: 28 additions & 14 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ static void typeassert_input(jl_codectx_t &ctx, const jl_cgval_t &jvinfo, jl_val
ctx.builder.CreateCondBr(istype, passBB, failBB);

ctx.builder.SetInsertPoint(failBB);
emit_type_error(ctx, mark_julia_type(ctx, vx, true, jl_any_type), boxed(ctx, jlto_runtime), msg);
just_emit_type_error(ctx, mark_julia_type(ctx, vx, true, jl_any_type), boxed(ctx, jlto_runtime), msg);
ctx.builder.CreateUnreachable();
ctx.builder.SetInsertPoint(passBB);
}
Expand Down Expand Up @@ -568,8 +568,15 @@ typedef struct {
jl_value_t *gcroot;
} native_sym_arg_t;

static inline const char *invalid_symbol_err_msg(bool ccall)
{
return ccall ?
"ccall: first argument not a pointer or valid constant expression" :
"cglobal: first argument not a pointer or valid constant expression";
}

// --- parse :sym or (:sym, :lib) argument into address info ---
static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_value_t *arg, const char *fname, bool llvmcall)
static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_value_t *arg, bool ccall, bool llvmcall)
{
Value *&jl_ptr = out.jl_ptr;
void (*&fptr)(void) = out.fptr;
Expand Down Expand Up @@ -599,9 +606,7 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va
jl_cgval_t arg1 = emit_expr(ctx, arg);
jl_value_t *ptr_ty = arg1.typ;
if (!jl_is_cpointer_type(ptr_ty)) {
const char *errmsg = !strcmp(fname, "ccall") ?
"ccall: first argument not a pointer or valid constant expression" :
"cglobal: first argument not a pointer or valid constant expression";
const char *errmsg = invalid_symbol_err_msg(ccall);
emit_cpointercheck(ctx, arg1, errmsg);
}
arg1 = update_julia_type(ctx, arg1, (jl_value_t*)jl_voidpointer_type);
Expand Down Expand Up @@ -647,19 +652,14 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va
f_name = jl_symbol_name((jl_sym_t*)t0);
else if (jl_is_string(t0))
f_name = jl_string_data(t0);
else
JL_TYPECHKS(fname, symbol, t0);

jl_value_t *t1 = jl_fieldref(ptr, 1);
if (jl_is_symbol(t1))
f_lib = jl_symbol_name((jl_sym_t*)t1);
else if (jl_is_string(t1))
f_lib = jl_string_data(t1);
else
JL_TYPECHKS(fname, symbol, t1);
}
else {
JL_TYPECHKS(fname, pointer, ptr);
f_name = NULL;
}
}
}
Expand Down Expand Up @@ -696,7 +696,15 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
Type *lrt = ctx.types().T_size;
assert(lrt == julia_type_to_llvm(ctx, rt));

interpret_symbol_arg(ctx, sym, args[1], "cglobal", false);
interpret_symbol_arg(ctx, sym, args[1], /*ccall=*/false, false);

if (sym.f_name == NULL && sym.fptr == NULL && sym.jl_ptr == NULL && sym.gcroot != NULL) {
const char *errmsg = invalid_symbol_err_msg(/*ccall=*/false);
jl_cgval_t arg1 = emit_expr(ctx, args[1]);
emit_type_error(ctx, arg1, literal_pointer_val(ctx, (jl_value_t *)jl_pointer_type), errmsg);
JL_GC_POP();
return jl_cgval_t();
}

if (sym.jl_ptr != NULL) {
res = ctx.builder.CreateBitCast(sym.jl_ptr, lrt);
Expand Down Expand Up @@ -1346,14 +1354,20 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
bool llvmcall = false;
std::tie(cc, llvmcall) = convert_cconv(cc_sym);

interpret_symbol_arg(ctx, symarg, args[1], "ccall", llvmcall);
interpret_symbol_arg(ctx, symarg, args[1], /*ccall=*/true, llvmcall);
Value *&jl_ptr = symarg.jl_ptr;
void (*&fptr)(void) = symarg.fptr;
const char *&f_name = symarg.f_name;
const char *&f_lib = symarg.f_lib;

if (f_name == NULL && fptr == NULL && jl_ptr == NULL) {
emit_error(ctx, "ccall: null function pointer");
if (symarg.gcroot != NULL) { // static_eval(ctx, args[1]) could not be interpreted to a function pointer
const char *errmsg = invalid_symbol_err_msg(/*ccall=*/true);
jl_cgval_t arg1 = emit_expr(ctx, args[1]);
emit_type_error(ctx, arg1, literal_pointer_val(ctx, (jl_value_t *)jl_pointer_type), errmsg);
} else {
emit_error(ctx, "ccall: null function pointer");
}
JL_GC_POP();
return jl_cgval_t();
}
Expand Down
17 changes: 11 additions & 6 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,13 +1352,21 @@ static Value *emit_typeof(jl_codectx_t &ctx, Value *v, bool maybenull)
}


static void emit_type_error(jl_codectx_t &ctx, const jl_cgval_t &x, Value *type, const std::string &msg)
static void just_emit_type_error(jl_codectx_t &ctx, const jl_cgval_t &x, Value *type, const std::string &msg)
{
Value *msg_val = stringConstPtr(ctx.emission_context, ctx.builder, msg);
ctx.builder.CreateCall(prepare_call(jltypeerror_func),
{ msg_val, maybe_decay_untracked(ctx, type), mark_callee_rooted(ctx, boxed(ctx, x))});
}

static void emit_type_error(jl_codectx_t &ctx, const jl_cgval_t &x, Value *type, const std::string &msg)
{
just_emit_type_error(ctx, x, type, msg);
ctx.builder.CreateUnreachable();
BasicBlock *cont = BasicBlock::Create(ctx.builder.getContext(), "after_type_error", ctx.f);
ctx.builder.SetInsertPoint(cont);
}

// Should agree with `emit_isa` below
static bool _can_optimize_isa(jl_value_t *type, int &counter)
{
Expand Down Expand Up @@ -1441,9 +1449,6 @@ static std::pair<Value*, bool> emit_isa(jl_codectx_t &ctx, const jl_cgval_t &x,
if (known_isa) {
if (!*known_isa && msg) {
emit_type_error(ctx, x, literal_pointer_val(ctx, type), *msg);
ctx.builder.CreateUnreachable();
BasicBlock *failBB = BasicBlock::Create(ctx.builder.getContext(), "fail", ctx.f);
ctx.builder.SetInsertPoint(failBB);
}
return std::make_pair(ConstantInt::get(getInt1Ty(ctx.builder.getContext()), *known_isa), true);
}
Expand Down Expand Up @@ -1581,7 +1586,7 @@ static void emit_typecheck(jl_codectx_t &ctx, const jl_cgval_t &x, jl_value_t *t
ctx.builder.CreateCondBr(istype, passBB, failBB);
ctx.builder.SetInsertPoint(failBB);

emit_type_error(ctx, x, literal_pointer_val(ctx, type), msg);
just_emit_type_error(ctx, x, literal_pointer_val(ctx, type), msg);
ctx.builder.CreateUnreachable();

ctx.f->getBasicBlockList().push_back(passBB);
Expand Down Expand Up @@ -3464,7 +3469,7 @@ static void emit_cpointercheck(jl_codectx_t &ctx, const jl_cgval_t &x, const std
ctx.builder.CreateCondBr(istype, passBB, failBB);
ctx.builder.SetInsertPoint(failBB);

emit_type_error(ctx, x, literal_pointer_val(ctx, (jl_value_t*)jl_pointer_type), msg);
just_emit_type_error(ctx, x, literal_pointer_val(ctx, (jl_value_t*)jl_pointer_type), msg);
ctx.builder.CreateUnreachable();

ctx.f->getBasicBlockList().push_back(passBB);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,14 @@ JL_DLLEXPORT jl_value_t *jl_cglobal(jl_value_t *v, jl_value_t *ty)

char *f_lib = NULL;
if (jl_is_tuple(v) && jl_nfields(v) > 1) {
jl_value_t *t1 = jl_fieldref_noalloc(v, 1);
v = jl_fieldref(v, 0);
jl_value_t *t1 = jl_fieldref(v, 1);
if (jl_is_symbol(t1))
f_lib = jl_symbol_name((jl_sym_t*)t1);
else if (jl_is_string(t1))
f_lib = jl_string_data(t1);
else
JL_TYPECHK(cglobal, symbol, t1)
v = jl_fieldref(v, 0);
}

char *f_name = NULL;
Expand Down
16 changes: 16 additions & 0 deletions test/ccall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,12 @@ end
@test_throws(ErrorException("ccall return type struct fields cannot contain a reference"),
@eval ccall(:fn, typeof(Ref("")), ()))

fn45187() = nothing

@test_throws(TypeError, @eval ccall(nothing, Cvoid, ()))
@test_throws(TypeError, @eval ccall(49142, Cvoid, ()))
@test_throws(TypeError, @eval ccall((:fn, fn45187), Cvoid, ()))

# test for malformed syntax errors
@test Expr(:error, "more arguments than types for ccall") == Meta.lower(@__MODULE__, :(ccall(:fn, A, (), x)))
@test Expr(:error, "more arguments than types for ccall") == Meta.lower(@__MODULE__, :(ccall(:fn, A, (B,), x, y)))
Expand Down Expand Up @@ -1910,6 +1916,12 @@ end
function cglobal33413_literal_notype()
return cglobal(:sin)
end
function cglobal49142_nothing()
return cglobal(nothing)
end
function cglobal45187fn()
return cglobal((:fn, fn45187))
end
@test unsafe_load(cglobal33413_ptrvar()) == 1
@test unsafe_load(cglobal33413_ptrinline()) == 1
@test unsafe_load(cglobal33413_tupleliteral()) == 1
Expand All @@ -1918,6 +1930,10 @@ end
@test unsafe_load(convert(Ptr{Cint}, cglobal33413_tupleliteral_notype())) == 1
@test cglobal33413_literal() != C_NULL
@test cglobal33413_literal_notype() != C_NULL
@test_throws(TypeError, cglobal49142_nothing())
@test_throws(TypeError, cglobal45187fn())
@test_throws(TypeError, @eval cglobal(nothing))
@test_throws(TypeError, @eval cglobal((:fn, fn45187)))
end

@testset "ccall_effects" begin
Expand Down

0 comments on commit 8f78a94

Please sign in to comment.