Skip to content

Commit

Permalink
Add the logic to tighten-up the inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 18, 2024
1 parent a02a510 commit a003506
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
}
if (input_execs->rnum <= 0 || output_execs->rnum <= 0)
continue;
ccv_nnc_graph_visit_t* const reverse_visit = ccv_nnc_graph_visit_new(graph, reversed_nodes, exec_rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(output_execs, 0), output_execs->rnum, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(input_execs, 0), input_execs->rnum, 1);
ccv_nnc_graph_visit_for(reverse_visit, exec_info, node, idx) {
if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags))
maskbit[idx >> 5] |= (1u << (idx & 0x1f));
} ccv_nnc_graph_visit_endfor
// Check if any of the items in input_execs is not marked, if it is not, there is no connection from it to output_execs, no need to visit.
for (j = 0; j < input_execs->rnum;)
{
const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(input_execs, j))->d;
if (maskbit[idx >> 5] & (1u << (idx & 0x1f)))
{
++j;
continue;
}
// Not marked, remove this one.
if (j < input_execs->rnum - 1)
*(ccv_nnc_tensor_symbol_t*)ccv_array_get(input_execs, j) = *(ccv_nnc_tensor_symbol_t*)ccv_array_get(input_execs, input_execs->rnum - 1);
--input_execs->rnum;
}
// Reset maskbit back.
ccv_nnc_graph_visit_for(reverse_visit, exec_info, node, idx) {
if (idx < exec_rnum && !CCV_NNC_GRAPH_EXEC_IS_DEAD(node->flags))
maskbit[idx >> 5] &= ~(1u << (idx & 0x1f));
} ccv_nnc_graph_visit_endfor
ccv_nnc_graph_visit_free(reverse_visit);
// Fill in blanks (i.e. the backward ops that are not showing in above, but should be included to avoid excluding necessary ones). This is done by flowing gradients from outputs back all the way to inputs.
ccv_array_clear(input_gradient_execs);
ccv_array_clear(output_gradient_execs);
Expand Down

0 comments on commit a003506

Please sign in to comment.