@@ -50,16 +50,32 @@ pub enum DiffActivity {
5050 /// with it.
5151 Dual ,
5252 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53+ /// with it. It expects the shadow argument to be `width` times larger than the original
54+ /// input/output.
55+ Dualv ,
56+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
5357 /// with it. Drop the code which updates the original input/output for maximum performance.
5458 DualOnly ,
59+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60+ /// with it. Drop the code which updates the original input/output for maximum performance.
61+ /// It expects the shadow argument to be `width` times larger than the original input/output.
62+ DualvOnly ,
5563 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
5664 Duplicated ,
5765 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
5866 /// Drop the code which updates the original input for maximum performance.
5967 DuplicatedOnly ,
6068 /// All Integers must be Const, but these are used to mark the integer which represents the
6169 /// length of a slice/vec. This is used for safety checks on slices.
62- FakeActivitySize ,
70+ /// The integer (if given) specifies the size of the slice element in bytes.
71+ FakeActivitySize ( Option < u32 > ) ,
72+ }
73+
74+ impl DiffActivity {
75+ pub fn is_dual_or_const ( & self ) -> bool {
76+ use DiffActivity :: * ;
77+ matches ! ( self , |Dual | DualOnly | Dualv | DualvOnly | Const )
78+ }
6379}
6480/// We generate one of these structs for each `#[autodiff(...)]` attribute.
6581#[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
@@ -131,11 +147,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
131147 match mode {
132148 DiffMode :: Error => false ,
133149 DiffMode :: Source => false ,
134- DiffMode :: Forward => {
135- activity == DiffActivity :: Dual
136- || activity == DiffActivity :: DualOnly
137- || activity == DiffActivity :: Const
138- }
150+ DiffMode :: Forward => activity. is_dual_or_const ( ) ,
139151 DiffMode :: Reverse => {
140152 activity == DiffActivity :: Const
141153 || activity == DiffActivity :: Active
@@ -153,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
153165pub fn valid_ty_for_activity ( ty : & P < Ty > , activity : DiffActivity ) -> bool {
154166 use DiffActivity :: * ;
155167 // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
156- if matches ! ( activity, Const ) {
157- return true ;
158- }
159- if matches ! ( activity, Dual | DualOnly ) {
168+ // Dual variants also support all types.
169+ if activity. is_dual_or_const ( ) {
160170 return true ;
161171 }
162172 // FIXME(ZuseZ4) We should make this more robust to also
@@ -172,9 +182,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
172182 return match mode {
173183 DiffMode :: Error => false ,
174184 DiffMode :: Source => false ,
175- DiffMode :: Forward => {
176- matches ! ( activity, Dual | DualOnly | Const )
177- }
185+ DiffMode :: Forward => activity. is_dual_or_const ( ) ,
178186 DiffMode :: Reverse => {
179187 matches ! ( activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const )
180188 }
@@ -189,10 +197,12 @@ impl Display for DiffActivity {
189197 DiffActivity :: Active => write ! ( f, "Active" ) ,
190198 DiffActivity :: ActiveOnly => write ! ( f, "ActiveOnly" ) ,
191199 DiffActivity :: Dual => write ! ( f, "Dual" ) ,
200+ DiffActivity :: Dualv => write ! ( f, "Dualv" ) ,
192201 DiffActivity :: DualOnly => write ! ( f, "DualOnly" ) ,
202+ DiffActivity :: DualvOnly => write ! ( f, "DualvOnly" ) ,
193203 DiffActivity :: Duplicated => write ! ( f, "Duplicated" ) ,
194204 DiffActivity :: DuplicatedOnly => write ! ( f, "DuplicatedOnly" ) ,
195- DiffActivity :: FakeActivitySize => write ! ( f, "FakeActivitySize" ) ,
205+ DiffActivity :: FakeActivitySize ( s ) => write ! ( f, "FakeActivitySize({:?})" , s ) ,
196206 }
197207 }
198208}
@@ -220,7 +230,9 @@ impl FromStr for DiffActivity {
220230 "ActiveOnly" => Ok ( DiffActivity :: ActiveOnly ) ,
221231 "Const" => Ok ( DiffActivity :: Const ) ,
222232 "Dual" => Ok ( DiffActivity :: Dual ) ,
233+ "Dualv" => Ok ( DiffActivity :: Dualv ) ,
223234 "DualOnly" => Ok ( DiffActivity :: DualOnly ) ,
235+ "DualvOnly" => Ok ( DiffActivity :: DualvOnly ) ,
224236 "Duplicated" => Ok ( DiffActivity :: Duplicated ) ,
225237 "DuplicatedOnly" => Ok ( DiffActivity :: DuplicatedOnly ) ,
226238 _ => Err ( ( ) ) ,
0 commit comments