Skip to content

Commit 10d380b

Browse files
authored
Misc ABI fixups (EnzymeAD#1200)
1 parent 674f22c commit 10d380b

File tree

5 files changed

+46
-10
lines changed

5 files changed

+46
-10
lines changed

src/api.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,8 @@ end
452452
ET_InternalError = 5,
453453
ET_TypeDepthExceeded = 6,
454454
ET_MixedActivityError = 7,
455-
ET_IllegalReplaceFicticiousPHIs = 8
455+
ET_IllegalReplaceFicticiousPHIs = 8,
456+
ET_GetIndexError = 9
456457
)
457458

458459
function EnzymeTypeAnalyzerToString(typeanalyzer)

src/compiler.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1845,6 +1845,20 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
18451845
end
18461846
emit_error(b, nothing, msg2)
18471847
return C_NULL
1848+
elseif errtype == API.ET_GetIndexError
1849+
@assert B != C_NULL
1850+
B = IRBuilder(B)
1851+
msg5 = sprint() do io::IO
1852+
print(io, "Enzyme internal error\n")
1853+
print(io, msg, '\n')
1854+
if bt !== nothing
1855+
print(io,"\nCaused by:")
1856+
Base.show_backtrace(io, bt)
1857+
println(io)
1858+
end
1859+
end
1860+
emit_error(B, nothing, msg5)
1861+
return C_NULL
18481862
end
18491863
throw(AssertionError("Unknown errtype"))
18501864
end
@@ -3367,6 +3381,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
33673381
elseif T <: Active
33683382
isboxed = GPUCompiler.deserves_argbox(T′)
33693383
if isboxed
3384+
if is_split
3385+
msg = sprint() do io
3386+
println(io, "Unimplemented: Had active input arg needing a box in split mode")
3387+
println(io, T, " at index ", i)
3388+
println(io, TT)
3389+
end
3390+
throw(AssertionError(msg))
3391+
end
33703392
@assert !is_split
33713393
# TODO replace with better enzyme_zero
33723394
ptr = gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), activeNum)])

src/rules/llvmrules.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,14 +1033,15 @@ function get_binding_or_error_fwd(B, orig, gutils, normalR, shadowR)
10331033
err = emit_error(B, orig, "Enzyme: unhandled forward for jl_get_binding_or_error")
10341034
newo = new_from_original(gutils, orig)
10351035
API.moveBefore(newo, err, B)
1036-
normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing
10371036

1038-
if shadowR != C_NULL && normal !== nothing
1037+
if unsafe_load(shadowR) != C_NULL
1038+
valTys = API.CValueType[API.VT_Primal, API.VT_Primal]
1039+
args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])]
1040+
normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false)
10391041
width = get_width(gutils)
10401042
if width == 1
10411043
shadowres = normal
10421044
else
1043-
position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal)))
10441045
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))))
10451046
for idx in 1:width
10461047
shadowres = insert_value!(B, shadowres, normal, idx-1)
@@ -1058,13 +1059,14 @@ function get_binding_or_error_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
10581059
err = emit_error(B, orig, "Enzyme: unhandled augmented forward for jl_get_binding_or_error")
10591060
newo = new_from_original(gutils, orig)
10601061
API.moveBefore(newo, err, B)
1061-
normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing
1062-
if shadowR != C_NULL && normal !== nothing
1062+
if unsafe_load(shadowR) != C_NULL
1063+
valTys = API.CValueType[API.VT_Primal, API.VT_Primal]
1064+
args = [new_from_original(gutils, operands(orig)[1]), new_from_original(gutils, operands(orig)[2])]
1065+
normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, #=lookup=#false)
10631066
width = get_width(gutils)
10641067
if width == 1
10651068
shadowres = normal
10661069
else
1067-
position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal)))
10681070
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))))
10691071
for idx in 1:width
10701072
shadowres = insert_value!(B, shadowres, normal, idx-1)

src/rules/typerules.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,16 @@ function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C
2828
end
2929

3030
function i64_box_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8
31-
TT = TypeTree(API.DT_Pointer, LLVM.context(LLVM.Value(val)))
31+
val = LLVM.Instruction(val)
32+
TT = TypeTree(API.DT_Pointer, LLVM.context(val))
33+
if (direction & API.DOWN) != 0
34+
sub = TypeTree(unsafe_load(args))
35+
ctx = LLVM.context(val)
36+
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(val)))))
37+
maxSize = div(width(value_type(operands(val)[1]))+7, 8)
38+
shift!(sub, dl, 0, maxSize, 0)
39+
API.EnzymeMergeTypeTree(TT, sub)
40+
end
3241
only!(TT, -1)
3342
API.EnzymeMergeTypeTree(ret, TT)
3443
return UInt8(false)
@@ -202,4 +211,4 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C
202211
end
203212

204213
return UInt8(false)
205-
end
214+
end

src/typetree.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ end
5353

5454
function merge!(dst::TypeTree, src::TypeTree; consume=true)
5555
API.EnzymeMergeTypeTree(dst, src)
56-
LLVM.dispose(src)
56+
if consume
57+
LLVM.dispose(src)
58+
end
5759
return nothing
5860
end
5961

0 commit comments

Comments
 (0)