Skip to content

Commit cb3037f

Browse files
authored
Disable sort on integers (EnzymeAD#1207)
* Disable sort on integers * fixup
1 parent ab9c91a commit cb3037f

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

src/compiler.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,13 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
16721672
ip = API.EnzymeTypeAnalyzerToString(data)
16731673
sval = Base.unsafe_string(ip)
16741674
API.EnzymeStringFree(ip)
1675+
1676+
if isa(val, LLVM.Instruction)
1677+
mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false)
1678+
if mi !== nothing
1679+
msg *= "\n" * string(mi) * "\n"
1680+
end
1681+
end
16751682
throw(IllegalTypeAnalysisException(msg, sval, ir, bt))
16761683
elseif errtype == API.ET_NoType
16771684
@assert B != C_NULL

src/internal_rules.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,9 +486,9 @@ end
486486
function EnzymeRules.forward(
487487
::Const{typeof(sort!)},
488488
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
489-
xs::Duplicated;
489+
xs::Duplicated{T};
490490
kwargs...
491-
)
491+
) where {T <: AbstractArray{<:AbstractFloat}}
492492
inds = sortperm(xs.val; kwargs...)
493493
xs.val .= xs.val[inds]
494494
xs.dval .= xs.dval[inds]
@@ -506,7 +506,7 @@ function EnzymeRules.forward(
506506
RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}},
507507
xs::BatchDuplicated{T, N};
508508
kwargs...
509-
) where {T, N}
509+
) where {T <: AbstractArray{<:AbstractFloat}, N}
510510
inds = sortperm(xs.val; kwargs...)
511511
xs.val .= xs.val[inds]
512512
for i in 1:N
@@ -521,13 +521,14 @@ function EnzymeRules.forward(
521521
end
522522
end
523523

524+
524525
function EnzymeRules.augmented_primal(
525526
config::EnzymeRules.ConfigWidth{1},
526527
::Const{typeof(sort!)},
527528
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
528-
xs::Duplicated;
529+
xs::Duplicated{T};
529530
kwargs...
530-
)
531+
) where {T <: AbstractArray{<:AbstractFloat}}
531532
inds = sortperm(xs.val; kwargs...)
532533
xs.val .= xs.val[inds]
533534
xs.dval .= xs.dval[inds]
@@ -549,9 +550,9 @@ function EnzymeRules.reverse(
549550
::Const{typeof(sort!)},
550551
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
551552
tape,
552-
xs::Duplicated;
553+
xs::Duplicated{T};
553554
kwargs...,
554-
)
555+
) where {T <: AbstractArray{<:AbstractFloat}}
555556
inds = tape
556557
back_inds = sortperm(inds)
557558
xs.dval .= xs.dval[back_inds]

test/internal_rules.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,19 @@ using Enzyme
44
using Enzyme.EnzymeRules
55
using Test
66

7-
@testset "Internal rules" begin
7+
struct TPair
8+
a::Float64
9+
b::Float64
10+
end
11+
12+
function sorterrfn(t, x)
13+
function lt(a, b)
14+
return a.a < b.a
15+
end
16+
return first(sortperm(t, lt=lt)) * x
17+
end
18+
19+
@testset "Sort rules" begin
820
function f1(x)
921
a = [1.0, 3.0, x]
1022
sort!(a)
@@ -27,6 +39,17 @@ using Test
2739
@test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3
2840
@test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0)
2941
@test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3
42+
43+
dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)])
44+
res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0))
45+
46+
@test res[1][2] 3
47+
@test dd.dval[1].a 0
48+
@test dd.dval[1].b 0
49+
@test dd.dval[2].a 0
50+
@test dd.dval[2].b 0
51+
@test dd.dval[3].a 0
52+
@test dd.dval[3].b 0
3053
end
3154

3255
@testset "Linear Solve" begin

0 commit comments

Comments
 (0)