Skip to content

Commit e83e561

Browse files
committed
Added handling for differentiation of symbols outside of metafunction type declaration.
1 parent 92d2407 commit e83e561

File tree

7 files changed

+1513
-1378
lines changed

7 files changed

+1513
-1378
lines changed

regression-tests/pure2-autodiff.cpp2

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
ad_name: namespace = {
22

3+
func_outer: (x: double, y: double) -> (ret: double) = {
4+
ret = x + y;
5+
}
6+
37
ad_test: @autodiff @print type = {
48

59
add_1: (x: double, y: double) -> (r: double) = {
@@ -58,6 +62,10 @@ ad_test: @autodiff @print type = {
5862
r = x * func(x, y);
5963
}
6064

65+
func_outer_call: (x: double, y: double) -> (r: double) = {
66+
r = x * func_outer(x, y);
67+
}
68+
6169
sin_call: (x: double, y: double) -> (r: double) = {
6270
r = sin(x - y);
6371
}
@@ -184,6 +192,7 @@ main: () = {
184192
write_output("x * (x + y)", x, x_d, y, y_d, ad_name::ad_test::mul_add_d(x, x_d, y, y_d));
185193
write_output("x + x * y", x, x_d, y, y_d, ad_name::ad_test::add_mul_d(x, x_d, y, y_d));
186194
write_output("x * func(x, y)", x, x_d, y, y_d, ad_name::ad_test::func_call_d(x, x_d, y, y_d));
195+
write_output("x * func_outer(x, y)", x, x_d, y, y_d, ad_name::ad_test::func_outer_call_d(x, x_d, y, y_d));
187196
write_output("sin(x - y)", x, x_d, y, y_d, ad_name::ad_test::sin_call_d(x, x_d, y, y_d));
188197
write_output("if branch", x, x_d, y, y_d, ad_name::ad_test::if_branch_d(x, x_d, y, y_d));
189198
write_output("if else branch", x, x_d, y, y_d, ad_name::ad_test::if_else_branch_d(x, x_d, y, y_d));

regression-tests/test-results/gcc-13-c++2b/pure2-autodiff.cpp.execution

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ diff(x * y / x) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000)
1111
diff(x * (x + y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
1212
diff(x + x * y) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 8.000000, r_d = 8.000000)
1313
diff(x * func(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
14+
diff(x * func_outer(x, y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 10.000000, r_d = 11.000000)
1415
diff(sin(x - y)) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = -0.841471, r_d = -0.540302)
1516
diff(if branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)
1617
diff(if else branch) at (x = 2.000000, x_d = 1.000000, y = 3.000000, y_d = 2.000000) = (r = 2.000000, r_d = 1.000000)

0 commit comments

Comments
 (0)