Skip to content

Commit 770b064

Browse files
jgreener64wsmoses
andauthored
sort! rules (EnzymeAD#1000)
* sort! rules * Sort in augmented primal * sort! rule tests * Batched sort! rule * Move A \ B rule test * Fix after rebase * Add missing end --------- Co-authored-by: William Moses <gh@wsmoses.com>
1 parent 10d380b commit 770b064

File tree

3 files changed

+162
-54
lines changed

3 files changed

+162
-54
lines changed

src/internal_rules.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,3 +482,78 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty
482482
return (nothing, nothing)
483483
end
484484
end
485+
486+
function EnzymeRules.forward(
487+
::Const{typeof(sort!)},
488+
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
489+
xs::Duplicated;
490+
kwargs...
491+
)
492+
inds = sortperm(xs.val; kwargs...)
493+
xs.val .= xs.val[inds]
494+
xs.dval .= xs.dval[inds]
495+
if RT <: Const
496+
return xs.val
497+
elseif RT <: DuplicatedNoNeed
498+
return xs.dval
499+
else
500+
return xs
501+
end
502+
end
503+
504+
function EnzymeRules.forward(
505+
::Const{typeof(sort!)},
506+
RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}},
507+
xs::BatchDuplicated{T, N};
508+
kwargs...
509+
) where {T, N}
510+
inds = sortperm(xs.val; kwargs...)
511+
xs.val .= xs.val[inds]
512+
for i in 1:N
513+
xs.dval[i] .= xs.dval[i][inds]
514+
end
515+
if RT <: Const
516+
return xs.val
517+
elseif RT <: BatchDuplicatedNoNeed
518+
return xs.dval
519+
else
520+
return xs
521+
end
522+
end
523+
524+
function EnzymeRules.augmented_primal(
525+
config::EnzymeRules.ConfigWidth{1},
526+
::Const{typeof(sort!)},
527+
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
528+
xs::Duplicated;
529+
kwargs...
530+
)
531+
inds = sortperm(xs.val; kwargs...)
532+
xs.val .= xs.val[inds]
533+
xs.dval .= xs.dval[inds]
534+
if EnzymeRules.needs_primal(config)
535+
primal = xs.val
536+
else
537+
primal = nothing
538+
end
539+
if RT <: Const
540+
shadow = nothing
541+
else
542+
shadow = xs.dval
543+
end
544+
return EnzymeRules.AugmentedReturn(primal, shadow, inds)
545+
end
546+
547+
function EnzymeRules.reverse(
548+
config::EnzymeRules.ConfigWidth{1},
549+
::Const{typeof(sort!)},
550+
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
551+
tape,
552+
xs::Duplicated;
553+
kwargs...,
554+
)
555+
inds = tape
556+
back_inds = sortperm(inds)
557+
xs.dval .= xs.dval[back_inds]
558+
return (nothing,)
559+
end

test/internal_rules.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
module InternalRules
2+
3+
using Enzyme
4+
using Enzyme.EnzymeRules
5+
using Test
6+
7+
@testset "Internal rules" begin
8+
function f1(x)
9+
a = [1.0, 3.0, x]
10+
sort!(a)
11+
return a[2]
12+
end
13+
14+
@test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1
15+
@test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0)
16+
@test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1
17+
@test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0
18+
@test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0)
19+
@test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0
20+
21+
function f2(x)
22+
a = [1.0, -3.0, -x, -2x, x]
23+
sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y))
24+
return sum(a .* [1, 2, 3, 4, 5])
25+
end
26+
27+
@test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3
28+
@test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0)
29+
@test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3
30+
end
31+
32+
@testset "Linear Solve" begin
33+
A = Float64[2 3; 5 7]
34+
dA = zero(A)
35+
b = Float64[11, 13]
36+
db = zero(b)
37+
38+
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)})
39+
40+
tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db))
41+
42+
dy = Float64[17, 19]
43+
copyto!(shadow, dy)
44+
45+
pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape)
46+
47+
z = transpose(A) \ dy
48+
49+
y = A \ b
50+
@test dA (-z * transpose(y))
51+
@test db z
52+
53+
db = zero(b)
54+
55+
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})
56+
57+
tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))
58+
59+
dy = Float64[17, 19]
60+
copyto!(shadow, dy)
61+
62+
pullback(Const(\), Const(A), Duplicated(b, db), tape)
63+
64+
z = transpose(A) \ dy
65+
66+
y = A \ b
67+
@test db z
68+
69+
dA = zero(A)
70+
71+
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})
72+
73+
tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))
74+
75+
dy = Float64[17, 19]
76+
copyto!(shadow, dy)
77+
78+
pullback(Const(\), Duplicated(A, dA), Const(b), tape)
79+
80+
z = transpose(A) \ dy
81+
82+
y = A \ b
83+
@test dA (-z * transpose(y))
84+
end
85+
86+
end # InternalRules

test/runtests.jl

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ include("typetree.jl")
7676
include("rrules.jl")
7777
include("kwrules.jl")
7878
include("kwrrules.jl")
79+
include("internal_rules.jl")
7980
@static if VERSION v"1.9-"
8081
# XXX invalidation does not work on Julia 1.8
8182
include("ruleinvalidation.jl")
@@ -2615,60 +2616,6 @@ end
26152616
end
26162617
end
26172618

2618-
@testset "Linear Solve" begin
2619-
A = Float64[2 3; 5 7]
2620-
dA = zero(A)
2621-
b = Float64[11, 13]
2622-
db = zero(b)
2623-
2624-
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)})
2625-
2626-
tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db))
2627-
2628-
dy = Float64[17, 19]
2629-
copyto!(shadow, dy)
2630-
2631-
pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape)
2632-
2633-
z = transpose(A) \ dy
2634-
2635-
y = A \ b
2636-
@test dA (-z * transpose(y))
2637-
@test db z
2638-
2639-
db = zero(b)
2640-
2641-
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})
2642-
2643-
tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))
2644-
2645-
dy = Float64[17, 19]
2646-
copyto!(shadow, dy)
2647-
2648-
pullback(Const(\), Const(A), Duplicated(b, db), tape)
2649-
2650-
z = transpose(A) \ dy
2651-
2652-
y = A \ b
2653-
@test db z
2654-
2655-
dA = zero(A)
2656-
2657-
forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})
2658-
2659-
tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))
2660-
2661-
dy = Float64[17, 19]
2662-
copyto!(shadow, dy)
2663-
2664-
pullback(Const(\), Duplicated(A, dA), Const(b), tape)
2665-
2666-
z = transpose(A) \ dy
2667-
2668-
y = A \ b
2669-
@test dA (-z * transpose(y))
2670-
end
2671-
26722619
@static if VERSION >= v"1.7-"
26732620
@testset "hvcat_fill" begin
26742621
ar = Matrix{Float64}(undef, 2, 3)

0 commit comments

Comments
 (0)