@@ -4592,6 +4592,7 @@ autodiff_diff_code: type = {
4592
4592
operator=:(out this, ctx_: *autodiff_context) = {
4593
4593
ctx = ctx_;
4594
4594
}
4595
+ operator=:(out this, that) = {}
4595
4596
4596
4597
add_forward : (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
4597
4598
add_reverse_primal : (inout this, v: std::string) = { if ctx*.is_reverse() { rws_primal += v; }}
@@ -4777,6 +4778,7 @@ autodiff_expression_handler: type = {
4777
4778
4778
4779
return r;
4779
4780
}
4781
+ prepare_backprop: (this, rhs_b: std::string, lhs: std::string) -> std::string = prepare_backprop(rhs_b, lhs, lhs + ctx*.fwd_suffix, lhs + ctx*.rws_suffix);
4780
4782
4781
4783
gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string, rhs: std::string, rhs_d: std::string, rhs_b: std::string) = {
4782
4784
diff.add_forward("(lhs_d)$ = (rhs_d)$;\n");
@@ -5386,18 +5388,21 @@ autodiff_stmt_handler: type = {
5386
5388
mf: meta::function_declaration;
5387
5389
5388
5390
last_params: std::vector<meta::parameter_declaration> = ();
5391
+ overwritten: std::vector<std::string> = ();
5392
+
5393
+ overwrite_push_pop: bool = false;
5389
5394
5390
5395
operator=: (out this, ctx_: *autodiff_context, mf_: meta::function_declaration) = {
5391
5396
autodiff_handler_base = (ctx_);
5392
5397
mf = mf_;
5393
5398
}
5394
5399
5395
- handle_stmt_parameters: (inout this, params: std::vector<parameter_declaration>, leave_open: bool) = {
5400
+ handle_stmt_parameters: (inout this, params: std::vector<parameter_declaration>) -> autodiff_diff_code = {
5401
+ r : autodiff_diff_code = (ctx);
5396
5402
if params.empty() {
5397
- return;
5403
+ return r ;
5398
5404
}
5399
5405
5400
- fwd: std::string = "(";
5401
5406
for params do (param) {
5402
5407
name: std::string = param.get_declaration().name();
5403
5408
type: std::string = param.get_declaration().type();
@@ -5409,6 +5414,7 @@ autodiff_stmt_handler: type = {
5409
5414
5410
5415
init : std::string = "";
5411
5416
init_d: std::string = "";
5417
+ // TODO: Add handling for reverse expressions
5412
5418
5413
5419
if param.get_declaration().has_initializer() {
5414
5420
ad: autodiff_expression_handler = (ctx);
@@ -5421,19 +5427,16 @@ autodiff_stmt_handler: type = {
5421
5427
}
5422
5428
5423
5429
5424
- fwd += "(fwd_pass_style)$ (name)$ : (type)$(init)$, ";
5430
+ r.add_forward("(fwd_pass_style)$ (name)$ : (type)$(init)$, ");
5431
+ r.add_reverse_primal("(fwd_pass_style)$ (name)$ : (type)$(init)$, ");
5425
5432
if ada.active {
5426
- fwd += "(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$(init_d)$, ";
5433
+ r.add_forward( "(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$(init_d)$, ") ;
5427
5434
}
5428
5435
5429
5436
ctx*.add_variable_declaration(name, type, ada.active);
5430
5437
}
5431
5438
5432
- if !leave_open {
5433
- fwd += ")";
5434
- }
5435
-
5436
- diff += fwd;
5439
+ return r;
5437
5440
}
5438
5441
5439
5442
traverse: (override inout this, decl: meta::declaration) = {
@@ -5465,9 +5468,11 @@ autodiff_stmt_handler: type = {
5465
5468
if active {
5466
5469
5467
5470
fwd_ad_type : = ctx*.get_fwd_ad_type(type);
5471
+ rws_ad_type : = ctx*.get_rws_ad_type(type);
5468
5472
5469
5473
prim_init: std::string = "";
5470
5474
fwd_init : std::string = "";
5475
+ rws_init : std::string = "";
5471
5476
5472
5477
if o.has_initializer() {
5473
5478
ad: autodiff_expression_handler = (ctx);
@@ -5476,14 +5481,23 @@ autodiff_stmt_handler: type = {
5476
5481
5477
5482
prim_init = " = " + ad.primal_expr;
5478
5483
fwd_init = " = " + ad.fwd_expr;
5484
+ rws_init = " = ()"; // TODO: Proper initialization.
5485
+
5486
+ if ad.rws_expr != "()" {
5487
+ diff.add_reverse_backprop(ad.prepare_backprop(ad.rws_expr, lhs));
5488
+ }
5479
5489
5480
5490
if type == "_" && ad.fwd_expr == "()" {
5481
5491
// Special handling for auto initialization from a literal.
5482
5492
fwd_init = " = " + ctx*.get_fwd_ad_type("double") + "()";
5483
5493
}
5484
5494
}
5485
- diff += "(lhs)$(ctx*.fwd_suffix)$ : (fwd_ad_type)$(fwd_init)$;\n";
5486
- diff += "(lhs)$ : (type)$(prim_init)$;\n";
5495
+
5496
+ diff.add_forward("(lhs)$(ctx*.fwd_suffix)$ : (fwd_ad_type)$(fwd_init)$;\n");
5497
+ diff.add_forward("(lhs)$ : (type)$(prim_init)$;\n");
5498
+
5499
+ diff.add_reverse_primal("(lhs)$(ctx*.rws_suffix)$ : (rws_ad_type)$(rws_init)$;\n");
5500
+ diff.add_reverse_primal("(lhs)$ : (type)$(prim_init)$;\n");
5487
5501
}
5488
5502
else {
5489
5503
diff += "(lhs)$: (type)$";
@@ -5515,9 +5529,37 @@ autodiff_stmt_handler: type = {
5515
5529
5516
5530
5517
5531
traverse: (override inout this, stmt: meta::compound_statement) = {
5518
- diff += "{\n";
5519
- base::traverse(stmt);
5520
- diff += "}\n";
5532
+ ad : autodiff_stmt_handler = (ctx, mf);
5533
+ ad_push_pop: autodiff_stmt_handler = (ctx, mf);
5534
+ ad_push_pop.overwrite_push_pop = true;
5535
+
5536
+ diff.add_forward("{\n");
5537
+ diff.add_reverse_primal("{\n");
5538
+ diff.add_reverse_backprop("}\n");
5539
+
5540
+ for stmt.get_statements() do (cur) {
5541
+ ad.pre_traverse(cur);
5542
+ ad_push_pop.pre_traverse(cur);
5543
+ }
5544
+
5545
+ for ad.overwritten do (cur) {
5546
+ r := ctx*.lookup_variable_declaration(cur);
5547
+ diff.add_reverse_primal("cpp2::ad_stack::push<(r.decl)$>((cur)$);");
5548
+ }
5549
+
5550
+ diff.add_forward(ad.diff.fwd);
5551
+ diff.add_reverse_primal(ad.diff.rws_primal);
5552
+ diff.add_reverse_backprop(ad_push_pop.diff.rws_backprop);
5553
+ diff.add_reverse_backprop(ad_push_pop.diff.rws_primal);
5554
+
5555
+ for ad.overwritten do (cur) {
5556
+ r := ctx*.lookup_variable_declaration(cur);
5557
+ diff.add_reverse_backprop("(cur)$ = cpp2::ad_stack::pop<(r.decl)$>();");
5558
+ }
5559
+
5560
+ diff.add_forward("}\n");
5561
+ diff.add_reverse_primal("}\n");
5562
+ diff.add_reverse_backprop("{\n");
5521
5563
}
5522
5564
5523
5565
@@ -5537,13 +5579,32 @@ autodiff_stmt_handler: type = {
5537
5579
}
5538
5580
}
5539
5581
5582
+ reverse_next: (this, expr: std::string) -> std::string = {
5583
+ if expr.contains("+=") {
5584
+ return string_util::replace_all(expr, "+=", "-=");
5585
+ }
5586
+ else if expr.contains("-=") {
5587
+ return string_util::replace_all(expr, "-=", "+=");
5588
+ }
5589
+
5590
+ mf.error("AD: Do not know how to reverse: (expr)$");
5591
+
5592
+ return "Error";
5593
+
5594
+ }
5595
+
5540
5596
5541
5597
traverse: (override inout this, stmt: meta::iteration_statement) = {
5542
- if !last_params.empty() {
5543
- handle_stmt_parameters(last_params, stmt.is_for());
5598
+ diff_params := handle_stmt_parameters(last_params);
5599
+
5600
+ if ctx*.is_reverse() && (stmt.is_while() || stmt.is_do()) {
5601
+ stmt.error("AD: Alpha limitiation now reverse mode for while or do while.");
5544
5602
}
5545
5603
5546
5604
if stmt.is_while() {
5605
+ if !last_params.empty() {
5606
+ diff.add_forward("(" + diff_params.fwd + ")");
5607
+ }
5547
5608
// TODO: Assumption is here that nothing is in the condition
5548
5609
diff += "while (stmt.get_do_while_condition().to_string())$ ";
5549
5610
if stmt.has_next() {
@@ -5554,6 +5615,10 @@ autodiff_stmt_handler: type = {
5554
5615
pre_traverse(stmt.get_do_while_body());
5555
5616
}
5556
5617
else if stmt.is_do() {
5618
+ if !last_params.empty() {
5619
+ diff.add_forward("(" + diff_params.fwd + ")");
5620
+ }
5621
+
5557
5622
// TODO: Assumption is here that nothing is in the condition
5558
5623
diff += "do ";
5559
5624
pre_traverse(stmt.get_do_while_body());
@@ -5574,23 +5639,55 @@ autodiff_stmt_handler: type = {
5574
5639
param := stmt.get_for_parameter();
5575
5640
param_style := to_string_view(param.get_passing_style());
5576
5641
param_decl := param.get_declaration();
5577
- if last_params.empty() {
5578
- diff += "("; // Open statment parameter scope. If the loop has parameters, they are alrady handled and the brace is left open.
5642
+
5643
+ rws : std::string = "(";
5644
+ rws_restore: std::string = "";
5645
+ diff.add_forward("("); // Open statment parameter scope. If the loop has parameters, they are alrady handled and the brace is left open.
5646
+ diff.add_reverse_primal("{\n");
5647
+ if !last_params.empty() {
5648
+ for last_params do (cur) {
5649
+ if cur.get_declaration().has_initializer() {
5650
+ // TODO: Handle no type and no initializer. Handle passing style.
5651
+ diff.add_reverse_primal("(cur.get_declaration().name())$: (cur.get_declaration().type())$ = (cur.get_declaration().get_initializer().to_string())$;\n");
5652
+ rws_restore += "cpp2::ad_stack::push<(cur.get_declaration().type())$>((cur.get_declaration().name())$);\n";
5653
+ rws += "(to_string_view(cur.get_passing_style()))$ (cur.get_declaration().name())$: (cur.get_declaration().type())$ = cpp2::ad_stack::pop<(cur.get_declaration().type())$>(), ";
5654
+ }
5655
+ }
5656
+ diff.add_forward(diff_params.fwd);
5579
5657
}
5580
- diff += "copy (param_decl.name())$_d_iter := (range)$(ctx*.fwd_suffix)$.begin())\n";
5581
- diff += "for (range)$ next (";
5658
+ diff.add_forward("copy (param_decl.name())$(ctx*.fwd_suffix)$_iter := (range)$(ctx*.fwd_suffix)$.begin())\n");
5659
+ diff.add_forward("for (range)$ next (");
5660
+
5661
+ rws += "copy (param_decl.name())$(ctx*.rws_suffix)$_iter := (range)$(ctx*.rws_suffix)$.rbegin())\n";
5662
+ rws += "for std::ranges::reverse_view((range)$) next (";
5663
+ diff.add_reverse_primal("for (range)$ next (");
5582
5664
if stmt.has_next() {
5583
5665
// TODO: Assumption is here that nothing is in the next expression
5584
- diff += "(stmt.get_next_expression().to_string())$, ";
5666
+ diff.add_forward("(stmt.get_next_expression().to_string())$, ");
5667
+ diff.add_reverse_primal("(stmt.get_next_expression().to_string())$, ");
5668
+ rws += "(reverse_next(stmt.get_next_expression().to_string()))$, ";
5585
5669
}
5586
- diff += "(param_decl.name())$_d_iter++";
5587
- diff += ") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n";
5588
- diff += "((param_style)$ (param_decl.name())$(ctx*.fwd_suffix)$: (param_decl.type())$ = (param_decl.name())$_d_iter*)";
5670
+ diff.add_forward("(param_decl.name())$(ctx*.fwd_suffix)$_iter++");
5671
+ diff.add_forward(") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n");
5672
+ rws += "(param_decl.name())$(ctx*.rws_suffix)$_iter++";
5673
+ rws += ") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n";
5674
+ rws += "(inout (param_decl.name())$(ctx*.rws_suffix)$ := (param_decl.name())$(ctx*.rws_suffix)$_iter*)\n";
5675
+
5676
+ diff.add_reverse_primal(") do ((param_style)$ (param_decl.name())$: (param_decl.type())$)");
5677
+ diff.add_forward("((param_style)$ (param_decl.name())$(ctx*.fwd_suffix)$: (param_decl.type())$ = (param_decl.name())$(ctx*.fwd_suffix)$_iter*)");
5589
5678
5590
5679
ctx*.add_variable_declaration("(param_decl.name())$", "(param_decl.type())$", true); // TODO: Handle loop/compound context variable declarations.
5680
+ diff.add_reverse_backprop("}\n");
5591
5681
5592
5682
pre_traverse(stmt.get_for_body());
5593
- diff += "}\n";
5683
+ diff.add_forward("}\n");
5684
+
5685
+ if stmt.has_next() {
5686
+ diff.add_reverse_primal("(reverse_next(stmt.get_next_expression().to_string()))$;\n");
5687
+ }
5688
+ diff.add_reverse_primal(rws_restore);
5689
+ diff.add_reverse_primal("}\n");
5690
+ diff.add_reverse_backprop(rws);
5594
5691
}
5595
5692
}
5596
5693
@@ -5626,8 +5723,34 @@ autodiff_stmt_handler: type = {
5626
5723
5627
5724
h: autodiff_expression_handler = (ctx);
5628
5725
h.pre_traverse(assignment_terms[1].get_term());
5629
- h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
5630
- append(h);
5726
+
5727
+ is_overwrite := h.primal_expr.contains(h_lhs.primal_expr);
5728
+ if overwrite_push_pop && is_overwrite {
5729
+ r := ctx*.lookup_variable_declaration(h_lhs.primal_expr);
5730
+ diff.add_reverse_primal("cpp2::ad_stack::push<(r.decl)$>((h_lhs.primal_expr)$);");
5731
+ }
5732
+
5733
+ if is_overwrite && ctx*.is_reverse() {
5734
+ t_b := ctx*.gen_temporary() + ctx*.rws_suffix;
5735
+ h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, t_b);
5736
+ append(h);
5737
+ diff.add_reverse_backprop("(h_lhs.rws_expr)$ = 0.0;\n");
5738
+ diff.add_reverse_backprop("(t_b)$ := (h_lhs.rws_expr)$;\n");
5739
+ }
5740
+ else {
5741
+ h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
5742
+ append(h);
5743
+ }
5744
+
5745
+ if overwrite_push_pop && is_overwrite {
5746
+ r := ctx*.lookup_variable_declaration(h_lhs.primal_expr);
5747
+ diff.add_reverse_backprop("(h_lhs.primal_expr)$ = cpp2::ad_stack::pop<(r.decl)$>();");
5748
+ }
5749
+
5750
+ // Simple overwrite check
5751
+ if is_overwrite {
5752
+ overwritten.push_back(h_lhs.primal_expr);
5753
+ }
5631
5754
}
5632
5755
else {
5633
5756
diff.add_forward(binexpr.to_string() + ";\n");
@@ -6046,6 +6169,9 @@ autodiff: (inout t: meta::type_declaration) =
6046
6169
if 1 != order {
6047
6170
t.add_runtime_support_include( "cpp2taylor.h" );
6048
6171
}
6172
+ if reverse {
6173
+ t.add_runtime_support_include( "cpp2ad_stack.h" );
6174
+ }
6049
6175
6050
6176
ad_ctx.finish();
6051
6177
0 commit comments