Skip to content

Commit f8ebf8d

Browse files
Merge pull request #478 from DhairyaLGandhi/dg/literal_voa
Add `literal_getproperty` disptach for `VectorOfArray`
2 parents a7db825 + 43adec0 commit f8ebf8d

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,30 @@ end
9999
end
100100
end
101101

102+
Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u})
103+
function literal_AbstractVofA_u_adjoint(d)
104+
dA = vofa_u_adjoint(d, A)
105+
(dA, nothing)
106+
end
107+
A.u, literal_AbstractVofA_u_adjoint
108+
end
109+
110+
function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractVectorOfArray)
111+
m = map(enumerate(d)) do (idx, d_i)
112+
isnothing(d_i) && return zero(A.u[idx])
113+
d_i
114+
end
115+
VectorOfArray(m)
116+
end
117+
118+
function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractDiffEqArray)
119+
m = map(enumerate(d)) do (idx, d_i)
120+
isnothing(d_i) && return zero(A.u[idx])
121+
d_i
122+
end
123+
DiffEqArray(m, A.t)
124+
end
125+
102126
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
103127
function literal_ArrayPartition_x_adjoint(d)
104128
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)

test/adjoints.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,9 @@ loss(x)
9292
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
9393
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
9494
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
95+
96+
voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
97+
voa_gs, = Zygote.gradient(voa) do x
98+
sum(sum.(x.u))
99+
end
100+
@test voa_gs isa RecursiveArrayTools.VectorOfArray

0 commit comments

Comments
 (0)