Skip to content

Commit d38145c

Browse files
authored
Expose the query ID and the last provisional value to the cycle recovery function (#1012)
* Expose the query ID and the last provisional value to the cycle recovery function * Mark cycle as converged if fallback value is the same as the last provisional * Make `cycle_fn` optional
1 parent 16d51d6 commit d38145c

33 files changed

+127
-281
lines changed

benches/dataflow.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type {
7676

7777
fn def_cycle_recover(
7878
_db: &dyn Db,
79+
_id: salsa::Id,
80+
_last_provisional_value: &Type,
7981
value: &Type,
8082
count: u32,
8183
_def: Definition,
@@ -89,6 +91,8 @@ fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type {
8991

9092
fn use_cycle_recover(
9193
_db: &dyn Db,
94+
_id: salsa::Id,
95+
_last_provisional_value: &Type,
9296
value: &Type,
9397
count: u32,
9498
_use: Use,

book/src/cycles.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ fn initial(_db: &dyn KnobsDatabase) -> u32 {
2121
}
2222
```
2323

24+
The `cycle_fn` is optional. The default implementation always returns `Iterate`.
25+
2426
If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should resume starting with the given value (which should be a value that will converge quickly).
2527

2628
The cycle will iterate until it converges: that is, until two successive iterations produce the same result.

components/salsa-macro-rules/src/setup_tracked_fn.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,13 @@ macro_rules! setup_tracked_fn {
308308

309309
fn recover_from_cycle<$db_lt>(
310310
db: &$db_lt dyn $Db,
311+
id: salsa::Id,
312+
last_provisional_value: &Self::Output<$db_lt>,
311313
value: &Self::Output<$db_lt>,
312-
count: u32,
314+
iteration_count: u32,
313315
($($input_id),*): ($($interned_input_ty),*)
314316
) -> $zalsa::CycleRecoveryAction<Self::Output<$db_lt>> {
315-
$($cycle_recovery_fn)*(db, value, count, $($input_id),*)
317+
$($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*)
316318
}
317319

318320
fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> {

components/salsa-macro-rules/src/unexpected_cycle_recovery.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
// a macro because it can take a variadic number of arguments.
44
#[macro_export]
55
macro_rules! unexpected_cycle_recovery {
6-
($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{
7-
std::mem::drop($db);
6+
($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{
7+
let (_db, _id, _last_provisional_value, _new_value, _count) = ($db, $id, $last_provisional_value, $new_value, $count);
88
std::mem::drop(($($other_inputs,)*));
9-
panic!("cannot recover from cycle")
9+
salsa::CycleRecoveryAction::Iterate
1010
}};
1111
}
1212

components/salsa-macros/src/tracked_fn.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,10 @@ impl Macro {
286286
self.args.cycle_fn.as_ref().unwrap(),
287287
"must provide `cycle_initial` along with `cycle_fn`",
288288
)),
289-
(None, Some(_), None) => Err(syn::Error::new_spanned(
290-
self.args.cycle_initial.as_ref().unwrap(),
291-
"must provide `cycle_fn` along with `cycle_initial`",
289+
(None, Some(cycle_initial), None) => Ok((
290+
quote!((salsa::plumbing::unexpected_cycle_recovery!)),
291+
quote!((#cycle_initial)),
292+
quote!(Fixpoint),
292293
)),
293294
(None, None, Some(cycle_result)) => Ok((
294295
quote!((salsa::plumbing::unexpected_cycle_recovery!)),

src/cycle.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ pub enum CycleRecoveryAction<T> {
7070
/// Iterate the cycle again to look for a fixpoint.
7171
Iterate,
7272

73-
/// Cut off iteration and use the given result value for this query.
73+
/// Use the given value as the result for the current iteration instead
74+
/// of the value computed by the query function.
75+
///
76+
/// Returning `Fallback` doesn't stop the fixpoint iteration. It only
77+
/// allows the iterate function to return a different value.
7478
Fallback(T),
7579
}
7680

src/function.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,32 @@ pub trait Configuration: Any {
9494
/// value from the latest iteration of this cycle. `count` is the number of cycle iterations
9595
/// completed so far.
9696
///
97-
/// # Iteration count semantics
97+
/// # Id
9898
///
99-
/// The `count` parameter isn't guaranteed to start from zero or to be contiguous:
99+
/// The id can be used to uniquely identify the query instance. This can be helpful
100+
/// if the cycle function has to re-identify a value it returned previously.
100101
///
101-
/// * **Initial value**: `count` may be non-zero on the first call for a given query if that
102+
/// # Values
103+
///
104+
/// The `last_provisional_value` is the value from the previous iteration of this cycle
105+
/// and `value` is the new value that was computed in the current iteration.
106+
///
107+
/// # Iteration count
108+
///
109+
/// The `iteration` parameter isn't guaranteed to start from zero or to be contiguous:
110+
///
111+
/// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that
102112
/// query becomes the outermost cycle head after a nested cycle complete a few iterations. In this case,
103-
/// `count` continues from the nested cycle's iteration count rather than resetting to zero.
113+
/// `iteration` continues from the nested cycle's iteration count rather than resetting to zero.
104114
/// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle
105115
/// and the value for this query remains unchanged for one iteration. But the outer cycle might
106116
/// keep iterating because other heads keep changing.
107117
fn recover_from_cycle<'db>(
108118
db: &'db Self::DbView,
109-
value: &Self::Output<'db>,
110-
count: u32,
119+
id: Id,
120+
last_provisional_value: &Self::Output<'db>,
121+
new_value: &Self::Output<'db>,
122+
iteration: u32,
111123
input: Self::Input<'db>,
112124
) -> CycleRecoveryAction<Self::Output<'db>>;
113125

src/function/execute.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ where
320320
I am a cycle head, comparing last provisional value with new value"
321321
);
322322

323-
let this_converged = C::values_equal(&new_value, last_provisional_value);
323+
let mut this_converged = C::values_equal(&new_value, last_provisional_value);
324324

325325
// If this is the outermost cycle, use the maximum iteration count of all cycles.
326326
// This is important for when later iterations introduce new cycle heads (that then
@@ -341,6 +341,8 @@ where
341341
// cycle-recovery function what to do:
342342
match C::recover_from_cycle(
343343
db,
344+
id,
345+
last_provisional_value,
344346
&new_value,
345347
iteration_count.as_u32(),
346348
C::id_to_input(zalsa, id),
@@ -351,6 +353,8 @@ where
351353
"{database_key_index:?}: execute: user cycle_fn says to fall back"
352354
);
353355
new_value = fallback_value;
356+
357+
this_converged = C::values_equal(&new_value, last_provisional_value);
354358
}
355359
}
356360
}

src/function/memo.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,8 @@ mod _memory_usage {
557557

558558
fn recover_from_cycle<'db>(
559559
_: &'db Self::DbView,
560+
_: Id,
561+
_: &Self::Output<'db>,
560562
_: &Self::Output<'db>,
561563
_: u32,
562564
_: Self::Input<'db>,

tests/backtrace.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ fn query_f(db: &dyn Database, thing: Thing) -> String {
4242
query_cycle(db, thing)
4343
}
4444

45-
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
45+
#[salsa::tracked(cycle_initial=cycle_initial)]
4646
fn query_cycle(db: &dyn Database, thing: Thing) -> String {
4747
let backtrace = query_cycle(db, thing);
4848
if backtrace.is_empty() {
@@ -56,15 +56,6 @@ fn cycle_initial(_db: &dyn salsa::Database, _thing: Thing) -> String {
5656
String::new()
5757
}
5858

59-
fn cycle_fn(
60-
_db: &dyn salsa::Database,
61-
_value: &str,
62-
_count: u32,
63-
_thing: Thing,
64-
) -> salsa::CycleRecoveryAction<String> {
65-
salsa::CycleRecoveryAction::Iterate
66-
}
67-
6859
#[test]
6960
fn backtrace_works() {
7061
let db = DatabaseImpl::default();

0 commit comments

Comments
 (0)