File tree Expand file tree Collapse file tree 2 files changed +30
-0
lines changed Expand file tree Collapse file tree 2 files changed +30
-0
lines changed Original file line number Diff line number Diff line change 99
99
end
100
100
end
101
101
102
+ Zygote. @adjoint function Zygote. literal_getproperty (A:: RecursiveArrayTools.AbstractVectorOfArray , :: Val{:u} )
103
+ function literal_AbstractVofA_u_adjoint (d)
104
+ dA = vofa_u_adjoint (d, A)
105
+ (dA, nothing )
106
+ end
107
+ A. u, literal_AbstractVofA_u_adjoint
108
+ end
109
+
110
+ function vofa_u_adjoint (d, A:: RecursiveArrayTools.AbstractVectorOfArray )
111
+ m = map (enumerate (d)) do (idx, d_i)
112
+ isnothing (d_i) && return zero (A. u[idx])
113
+ d_i
114
+ end
115
+ VectorOfArray (m)
116
+ end
117
+
118
+ function vofa_u_adjoint (d, A:: RecursiveArrayTools.AbstractDiffEqArray )
119
+ m = map (enumerate (d)) do (idx, d_i)
120
+ isnothing (d_i) && return zero (A. u[idx])
121
+ d_i
122
+ end
123
+ DiffEqArray (m, A. t)
124
+ end
125
+
102
126
@adjoint function literal_getproperty (A:: ArrayPartition , :: Val{:x} )
103
127
function literal_ArrayPartition_x_adjoint (d)
104
128
(ArrayPartition ((isnothing (d[i]) ? zero (A. x[i]) : d[i] for i in 1 : length (d)). .. ),)
Original file line number Diff line number Diff line change @@ -92,3 +92,9 @@ loss(x)
92
92
VectorOfArray ([collect ((3 i): (3 i + 3 )) for i in 1 : 5 ])
93
93
@test Zygote. gradient (loss10, x)[1 ] == ForwardDiff. gradient (loss10, x)
94
94
@test Zygote. gradient (loss11, x)[1 ] == ForwardDiff. gradient (loss11, x)
95
+
96
+ voa = RecursiveArrayTools. VectorOfArray (fill (rand (3 ), 3 ))
97
+ voa_gs, = Zygote. gradient (voa) do x
98
+ sum (sum .(x. u))
99
+ end
100
+ @test voa_gs isa RecursiveArrayTools. VectorOfArray
You can’t perform that action at this time.
0 commit comments