Skip to content

Commit dab7f8a

Browse files
committed
Generalize to arbitrary args.
1 parent 6335f86 commit dab7f8a

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/compiler/compilation.jl

+34-2
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,41 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
361361
return OpaqueClosure{id, typeof(env), sig, rt}(env)
362362
end
363363

364+
# generated function `ccall`, working around the restriction that ccall type
365+
# tuples need to be literals. this relies on ccall internals...
366+
@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
367+
ex = quote end
368+
369+
rettyp = _rettyp.parameters[1]
370+
types = _types.parameters[1].parameters
371+
args = [:(vals[$i]) for i in 1:length(vals)]
372+
373+
# cconvert
374+
cconverted = [Symbol("cconverted_$i") for i in 1:length(vals)]
375+
for (dst, typ, src) in zip(cconverted, types, args)
376+
append!(ex.args, (quote
377+
$dst = Base.cconvert($typ, $src)
378+
end).args)
379+
end
380+
381+
# unsafe_convert
382+
unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:length(vals)]
383+
for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
384+
append!(ex.args, (quote
385+
$dst = Base.unsafe_convert($typ, $src)
386+
end).args)
387+
end
388+
389+
call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
390+
QuoteNode(:ccall), unsafe_converted..., cconverted...)
391+
push!(ex.args, call)
392+
return ex
393+
end
394+
364395
# device-side call to an opaque closure
365-
function (oc::OpaqueClosure{F})(a, b) where F
396+
function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
366397
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
367398
assume(ptr != C_NULL)
368-
return ccall(ptr, Int, (Int, Int), a, b)
399+
#ccall(ptr, R, (A...), args...)
400+
generated_ccall(ptr, R, A, args...)
369401
end

test/execution.jl

+9-8
Original file line numberDiff line numberDiff line change
@@ -1110,21 +1110,22 @@ end
11101110

11111111
# basic closure, constructed from CodeInfo
11121112
let
1113-
ir, rettyp = only(Base.code_typed(+, (Int, Int)))
1113+
ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
11141114
oc = CUDA.OpaqueClosure(ir)
11151115

1116-
c = CuArray([0])
1117-
a = CuArray([1])
1118-
b = CuArray([2])
1116+
d = CuArray([1])
1117+
a = CuArray([2])
1118+
b = CuArray([3])
1119+
c = CuArray([4])
11191120

1120-
function kernel(oc, c, a, b)
1121+
function kernel(oc, d, a, b, c)
11211122
i = threadIdx().x
1122-
@inbounds c[i] = oc(a[i], b[i])
1123+
@inbounds d[i] = oc(a[i], b[i], c[i])
11231124
return
11241125
end
1125-
@cuda threads=1 kernel(oc, c, a, b)
1126+
@cuda threads=1 kernel(oc, d, a, b, c)
11261127

1127-
@test Array(c)[] == 3
1128+
@test Array(d)[] == 24
11281129
end
11291130

11301131
end

0 commit comments

Comments
 (0)