Skip to content

Commit c8fdb06

Browse files
committed
Reverse handling of for loops.
1 parent a05accd commit c8fdb06

File tree

1 file changed

+154
-28
lines changed

1 file changed

+154
-28
lines changed

source/reflect.h2

Lines changed: 154 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4592,6 +4592,7 @@ autodiff_diff_code: type = {
45924592
operator=:(out this, ctx_: *autodiff_context) = {
45934593
ctx = ctx_;
45944594
}
4595+
operator=:(out this, that) = {}
45954596

45964597
add_forward : (inout this, v: std::string) = { if ctx*.is_forward() { fwd += v; }}
45974598
add_reverse_primal : (inout this, v: std::string) = { if ctx*.is_reverse() { rws_primal += v; }}
@@ -4777,6 +4778,7 @@ autodiff_expression_handler: type = {
47774778

47784779
return r;
47794780
}
4781+
prepare_backprop: (this, rhs_b: std::string, lhs: std::string) -> std::string = prepare_backprop(rhs_b, lhs, lhs + ctx*.fwd_suffix, lhs + ctx*.rws_suffix);
47804782

47814783
gen_assignment: (inout this, lhs: std::string, lhs_d: std::string, lhs_b: std::string, rhs: std::string, rhs_d: std::string, rhs_b: std::string) = {
47824784
diff.add_forward("(lhs_d)$ = (rhs_d)$;\n");
@@ -5386,18 +5388,21 @@ autodiff_stmt_handler: type = {
53865388
mf: meta::function_declaration;
53875389

53885390
last_params: std::vector<meta::parameter_declaration> = ();
5391+
overwritten: std::vector<std::string> = ();
5392+
5393+
overwrite_push_pop: bool = false;
53895394

53905395
operator=: (out this, ctx_: *autodiff_context, mf_: meta::function_declaration) = {
53915396
autodiff_handler_base = (ctx_);
53925397
mf = mf_;
53935398
}
53945399

5395-
handle_stmt_parameters: (inout this, params: std::vector<parameter_declaration>, leave_open: bool) = {
5400+
handle_stmt_parameters: (inout this, params: std::vector<parameter_declaration>) -> autodiff_diff_code = {
5401+
r : autodiff_diff_code = (ctx);
53965402
if params.empty() {
5397-
return;
5403+
return r;
53985404
}
53995405

5400-
fwd: std::string = "(";
54015406
for params do (param) {
54025407
name: std::string = param.get_declaration().name();
54035408
type: std::string = param.get_declaration().type();
@@ -5409,6 +5414,7 @@ autodiff_stmt_handler: type = {
54095414

54105415
init : std::string = "";
54115416
init_d: std::string = "";
5417+
// TODO: Add handling for reverse expressions
54125418

54135419
if param.get_declaration().has_initializer() {
54145420
ad: autodiff_expression_handler = (ctx);
@@ -5421,19 +5427,16 @@ autodiff_stmt_handler: type = {
54215427
}
54225428

54235429

5424-
fwd += "(fwd_pass_style)$ (name)$ : (type)$(init)$, ";
5430+
r.add_forward("(fwd_pass_style)$ (name)$ : (type)$(init)$, ");
5431+
r.add_reverse_primal("(fwd_pass_style)$ (name)$ : (type)$(init)$, ");
54255432
if ada.active {
5426-
fwd += "(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$(init_d)$, ";
5433+
r.add_forward("(fwd_pass_style)$ (name)$(ctx*.fwd_suffix)$ : (ctx*.get_fwd_ad_type(type))$(init_d)$, ");
54275434
}
54285435

54295436
ctx*.add_variable_declaration(name, type, ada.active);
54305437
}
54315438

5432-
if !leave_open {
5433-
fwd += ")";
5434-
}
5435-
5436-
diff += fwd;
5439+
return r;
54375440
}
54385441

54395442
traverse: (override inout this, decl: meta::declaration) = {
@@ -5465,9 +5468,11 @@ autodiff_stmt_handler: type = {
54655468
if active {
54665469

54675470
fwd_ad_type : = ctx*.get_fwd_ad_type(type);
5471+
rws_ad_type : = ctx*.get_rws_ad_type(type);
54685472

54695473
prim_init: std::string = "";
54705474
fwd_init : std::string = "";
5475+
rws_init : std::string = "";
54715476

54725477
if o.has_initializer() {
54735478
ad: autodiff_expression_handler = (ctx);
@@ -5476,14 +5481,23 @@ autodiff_stmt_handler: type = {
54765481

54775482
prim_init = " = " + ad.primal_expr;
54785483
fwd_init = " = " + ad.fwd_expr;
5484+
rws_init = " = ()"; // TODO: Proper initialization.
5485+
5486+
if ad.rws_expr != "()" {
5487+
diff.add_reverse_backprop(ad.prepare_backprop(ad.rws_expr, lhs));
5488+
}
54795489

54805490
if type == "_" && ad.fwd_expr == "()" {
54815491
// Special handling for auto initialization from a literal.
54825492
fwd_init = " = " + ctx*.get_fwd_ad_type("double") + "()";
54835493
}
54845494
}
5485-
diff += "(lhs)$(ctx*.fwd_suffix)$ : (fwd_ad_type)$(fwd_init)$;\n";
5486-
diff += "(lhs)$ : (type)$(prim_init)$;\n";
5495+
5496+
diff.add_forward("(lhs)$(ctx*.fwd_suffix)$ : (fwd_ad_type)$(fwd_init)$;\n");
5497+
diff.add_forward("(lhs)$ : (type)$(prim_init)$;\n");
5498+
5499+
diff.add_reverse_primal("(lhs)$(ctx*.rws_suffix)$ : (rws_ad_type)$(rws_init)$;\n");
5500+
diff.add_reverse_primal("(lhs)$ : (type)$(prim_init)$;\n");
54875501
}
54885502
else {
54895503
diff += "(lhs)$: (type)$";
@@ -5515,9 +5529,37 @@ autodiff_stmt_handler: type = {
55155529

55165530

55175531
traverse: (override inout this, stmt: meta::compound_statement) = {
5518-
diff += "{\n";
5519-
base::traverse(stmt);
5520-
diff += "}\n";
5532+
ad : autodiff_stmt_handler = (ctx, mf);
5533+
ad_push_pop: autodiff_stmt_handler = (ctx, mf);
5534+
ad_push_pop.overwrite_push_pop = true;
5535+
5536+
diff.add_forward("{\n");
5537+
diff.add_reverse_primal("{\n");
5538+
diff.add_reverse_backprop("}\n");
5539+
5540+
for stmt.get_statements() do (cur) {
5541+
ad.pre_traverse(cur);
5542+
ad_push_pop.pre_traverse(cur);
5543+
}
5544+
5545+
for ad.overwritten do (cur) {
5546+
r := ctx*.lookup_variable_declaration(cur);
5547+
diff.add_reverse_primal("cpp2::ad_stack::push<(r.decl)$>((cur)$);");
5548+
}
5549+
5550+
diff.add_forward(ad.diff.fwd);
5551+
diff.add_reverse_primal(ad.diff.rws_primal);
5552+
diff.add_reverse_backprop(ad_push_pop.diff.rws_backprop);
5553+
diff.add_reverse_backprop(ad_push_pop.diff.rws_primal);
5554+
5555+
for ad.overwritten do (cur) {
5556+
r := ctx*.lookup_variable_declaration(cur);
5557+
diff.add_reverse_backprop("(cur)$ = cpp2::ad_stack::pop<(r.decl)$>();");
5558+
}
5559+
5560+
diff.add_forward("}\n");
5561+
diff.add_reverse_primal("}\n");
5562+
diff.add_reverse_backprop("{\n");
55215563
}
55225564

55235565

@@ -5537,13 +5579,32 @@ autodiff_stmt_handler: type = {
55375579
}
55385580
}
55395581

5582+
reverse_next: (this, expr: std::string) -> std::string = {
5583+
if expr.contains("+=") {
5584+
return string_util::replace_all(expr, "+=", "-=");
5585+
}
5586+
else if expr.contains("-=") {
5587+
return string_util::replace_all(expr, "-=", "+=");
5588+
}
5589+
5590+
mf.error("AD: Do not know how to reverse: (expr)$");
5591+
5592+
return "Error";
5593+
5594+
}
5595+
55405596

55415597
traverse: (override inout this, stmt: meta::iteration_statement) = {
5542-
if !last_params.empty() {
5543-
handle_stmt_parameters(last_params, stmt.is_for());
5598+
diff_params := handle_stmt_parameters(last_params);
5599+
5600+
if ctx*.is_reverse() && (stmt.is_while() || stmt.is_do()) {
5601+
stmt.error("AD: Alpha limitiation now reverse mode for while or do while.");
55445602
}
55455603

55465604
if stmt.is_while() {
5605+
if !last_params.empty() {
5606+
diff.add_forward("(" + diff_params.fwd + ")");
5607+
}
55475608
// TODO: Assumption is here that nothing is in the condition
55485609
diff += "while (stmt.get_do_while_condition().to_string())$ ";
55495610
if stmt.has_next() {
@@ -5554,6 +5615,10 @@ autodiff_stmt_handler: type = {
55545615
pre_traverse(stmt.get_do_while_body());
55555616
}
55565617
else if stmt.is_do() {
5618+
if !last_params.empty() {
5619+
diff.add_forward("(" + diff_params.fwd + ")");
5620+
}
5621+
55575622
// TODO: Assumption is here that nothing is in the condition
55585623
diff += "do ";
55595624
pre_traverse(stmt.get_do_while_body());
@@ -5574,23 +5639,55 @@ autodiff_stmt_handler: type = {
55745639
param := stmt.get_for_parameter();
55755640
param_style := to_string_view(param.get_passing_style());
55765641
param_decl := param.get_declaration();
5577-
if last_params.empty() {
5578-
diff += "("; // Open statment parameter scope. If the loop has parameters, they are alrady handled and the brace is left open.
5642+
5643+
rws : std::string = "(";
5644+
rws_restore: std::string = "";
5645+
diff.add_forward("("); // Open statment parameter scope. If the loop has parameters, they are alrady handled and the brace is left open.
5646+
diff.add_reverse_primal("{\n");
5647+
if !last_params.empty() {
5648+
for last_params do (cur) {
5649+
if cur.get_declaration().has_initializer() {
5650+
// TODO: Handle no type and no initializer. Handle passing style.
5651+
diff.add_reverse_primal("(cur.get_declaration().name())$: (cur.get_declaration().type())$ = (cur.get_declaration().get_initializer().to_string())$;\n");
5652+
rws_restore += "cpp2::ad_stack::push<(cur.get_declaration().type())$>((cur.get_declaration().name())$);\n";
5653+
rws += "(to_string_view(cur.get_passing_style()))$ (cur.get_declaration().name())$: (cur.get_declaration().type())$ = cpp2::ad_stack::pop<(cur.get_declaration().type())$>(), ";
5654+
}
5655+
}
5656+
diff.add_forward(diff_params.fwd);
55795657
}
5580-
diff += "copy (param_decl.name())$_d_iter := (range)$(ctx*.fwd_suffix)$.begin())\n";
5581-
diff += "for (range)$ next (";
5658+
diff.add_forward("copy (param_decl.name())$(ctx*.fwd_suffix)$_iter := (range)$(ctx*.fwd_suffix)$.begin())\n");
5659+
diff.add_forward("for (range)$ next (");
5660+
5661+
rws += "copy (param_decl.name())$(ctx*.rws_suffix)$_iter := (range)$(ctx*.rws_suffix)$.rbegin())\n";
5662+
rws += "for std::ranges::reverse_view((range)$) next (";
5663+
diff.add_reverse_primal("for (range)$ next (");
55825664
if stmt.has_next() {
55835665
// TODO: Assumption is here that nothing is in the next expression
5584-
diff += "(stmt.get_next_expression().to_string())$, ";
5666+
diff.add_forward("(stmt.get_next_expression().to_string())$, ");
5667+
diff.add_reverse_primal("(stmt.get_next_expression().to_string())$, ");
5668+
rws += "(reverse_next(stmt.get_next_expression().to_string()))$, ";
55855669
}
5586-
diff += "(param_decl.name())$_d_iter++";
5587-
diff += ") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n";
5588-
diff += "((param_style)$ (param_decl.name())$(ctx*.fwd_suffix)$: (param_decl.type())$ = (param_decl.name())$_d_iter*)";
5670+
diff.add_forward("(param_decl.name())$(ctx*.fwd_suffix)$_iter++");
5671+
diff.add_forward(") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n");
5672+
rws += "(param_decl.name())$(ctx*.rws_suffix)$_iter++";
5673+
rws += ") do ((param_style)$ (param_decl.name())$: (param_decl.type())$) {\n";
5674+
rws += "(inout (param_decl.name())$(ctx*.rws_suffix)$ := (param_decl.name())$(ctx*.rws_suffix)$_iter*)\n";
5675+
5676+
diff.add_reverse_primal(") do ((param_style)$ (param_decl.name())$: (param_decl.type())$)");
5677+
diff.add_forward("((param_style)$ (param_decl.name())$(ctx*.fwd_suffix)$: (param_decl.type())$ = (param_decl.name())$(ctx*.fwd_suffix)$_iter*)");
55895678

55905679
ctx*.add_variable_declaration("(param_decl.name())$", "(param_decl.type())$", true); // TODO: Handle loop/compound context variable declarations.
5680+
diff.add_reverse_backprop("}\n");
55915681

55925682
pre_traverse(stmt.get_for_body());
5593-
diff += "}\n";
5683+
diff.add_forward("}\n");
5684+
5685+
if stmt.has_next() {
5686+
diff.add_reverse_primal("(reverse_next(stmt.get_next_expression().to_string()))$;\n");
5687+
}
5688+
diff.add_reverse_primal(rws_restore);
5689+
diff.add_reverse_primal("}\n");
5690+
diff.add_reverse_backprop(rws);
55945691
}
55955692
}
55965693

@@ -5626,8 +5723,34 @@ autodiff_stmt_handler: type = {
56265723

56275724
h: autodiff_expression_handler = (ctx);
56285725
h.pre_traverse(assignment_terms[1].get_term());
5629-
h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
5630-
append(h);
5726+
5727+
is_overwrite := h.primal_expr.contains(h_lhs.primal_expr);
5728+
if overwrite_push_pop && is_overwrite {
5729+
r := ctx*.lookup_variable_declaration(h_lhs.primal_expr);
5730+
diff.add_reverse_primal("cpp2::ad_stack::push<(r.decl)$>((h_lhs.primal_expr)$);");
5731+
}
5732+
5733+
if is_overwrite && ctx*.is_reverse() {
5734+
t_b := ctx*.gen_temporary() + ctx*.rws_suffix;
5735+
h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, t_b);
5736+
append(h);
5737+
diff.add_reverse_backprop("(h_lhs.rws_expr)$ = 0.0;\n");
5738+
diff.add_reverse_backprop("(t_b)$ := (h_lhs.rws_expr)$;\n");
5739+
}
5740+
else {
5741+
h.gen_assignment(h_lhs.primal_expr, h_lhs.fwd_expr, h_lhs.rws_expr);
5742+
append(h);
5743+
}
5744+
5745+
if overwrite_push_pop && is_overwrite {
5746+
r := ctx*.lookup_variable_declaration(h_lhs.primal_expr);
5747+
diff.add_reverse_backprop("(h_lhs.primal_expr)$ = cpp2::ad_stack::pop<(r.decl)$>();");
5748+
}
5749+
5750+
// Simple overwrite check
5751+
if is_overwrite {
5752+
overwritten.push_back(h_lhs.primal_expr);
5753+
}
56315754
}
56325755
else {
56335756
diff.add_forward(binexpr.to_string() + ";\n");
@@ -6046,6 +6169,9 @@ autodiff: (inout t: meta::type_declaration) =
60466169
if 1 != order {
60476170
t.add_runtime_support_include( "cpp2taylor.h" );
60486171
}
6172+
if reverse {
6173+
t.add_runtime_support_include( "cpp2ad_stack.h" );
6174+
}
60496175

60506176
ad_ctx.finish();
60516177

0 commit comments

Comments
 (0)