Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ cargo add tree_traversal
use tree_traversal::bbs::bbs;

type Node = Vec<bool>;

fn main() {
let weights = [4, 2, 6, 3, 4];
let profits = [100, 20, 2, 5, 10];
Expand Down Expand Up @@ -79,8 +78,6 @@ fn main() {
s
};

// tree traversal assumes a minimization problem
// if you want to solve maximization problem, subtract your actual score from the MAX value
let lower_bound_fn = |n: &Node| {
let current_profit = total_profit(n);
let max_remained_profit: u32 = profits[n.len()..].into_iter().sum();
Expand All @@ -90,9 +87,17 @@ fn main() {
let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n));

let leaf_check_fn = |n: &Node| n.len() == total_items;

let (cost, best_node) =
bbs(vec![], successor_fn, lower_bound_fn, cost_fn, leaf_check_fn).unwrap();
let max_ops = usize::MAX;

let (cost, best_node) = bbs(
vec![],
successor_fn,
lower_bound_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();
let cost = u32::MAX - cost;

dbg!((best_node, cost));
Expand Down
12 changes: 10 additions & 2 deletions examples/bbs_knapsack_problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@ fn main() {
let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n));

let leaf_check_fn = |n: &Node| n.len() == total_items;
let max_ops = usize::MAX;

let (cost, best_node) =
bbs(vec![], successor_fn, lower_bound_fn, cost_fn, leaf_check_fn).unwrap();
let (cost, best_node) = bbs(
vec![],
successor_fn,
lower_bound_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();
let cost = u32::MAX - cost;

dbg!((best_node, cost));
Expand Down
24 changes: 21 additions & 3 deletions src/bbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct BbsReachable<N, FN, FC, C> {
successor_fn: FN,
lower_bound_fn: FC,
current_best_cost: C,
remained_ops: usize,
}

impl<N, FN, IN, FC, C> Iterator for BbsReachable<N, FN, FC, C>
Expand All @@ -22,6 +23,10 @@ where
type Item = N;

fn next(&mut self) -> Option<Self::Item> {
if self.remained_ops == 0 {
return None;
}
self.remained_ops -= 1;
if let Some(n) = self.to_see.pop() {
// get lower bound
if let Some(lb) = (self.lower_bound_fn)(&n) {
Expand Down Expand Up @@ -53,6 +58,7 @@ pub fn bbs_reach<N, FN, IN, FC, C>(
start: N,
successor_fn: FN,
lower_bound_fn: FC,
max_ops: usize,
) -> BbsReachable<N, FN, FC, C>
where
N: Clone,
Expand All @@ -66,6 +72,7 @@ where
successor_fn,
lower_bound_fn,
current_best_cost: C::max_value(),
remained_ops: max_ops,
}
}

Expand All @@ -76,6 +83,7 @@ where
/// - `lower_bound_fn` returns the lower bound of a given node do decide wheather search deeper or not
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn bbs<N, IN, FN, FC1, FC2, C, FR>(
Expand All @@ -84,6 +92,7 @@ pub fn bbs<N, IN, FN, FC1, FC2, C, FR>(
lower_bound_fn: FC1,
cost_fn: FC2,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -94,7 +103,7 @@ where
C: Ord + Copy + Bounded,
FR: Fn(&N) -> bool,
{
let mut res = bbs_reach(start, successor_fn, lower_bound_fn);
let mut res = bbs_reach(start, successor_fn, lower_bound_fn, max_ops);
let mut best_leaf_node = None;
loop {
let op_n = res.next();
Expand Down Expand Up @@ -186,8 +195,17 @@ mod test {

let leaf_check_fn = |n: &Node| n.len() == total_items;

let (cost, best_node) =
bbs(vec![], successor_fn, lower_bound_fn, cost_fn, leaf_check_fn).unwrap();
let max_ops = usize::MAX;

let (cost, best_node) = bbs(
vec![],
successor_fn,
lower_bound_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();
let cost = u32::MAX - cost;

assert_eq!(cost, 120);
Expand Down
6 changes: 5 additions & 1 deletion src/bfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ use crate::bms::bms;
/// - `successor_fn` returns a list of successors for a given node.
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn bfs<N, IN, FN, FC, C, FR>(
start: N,
successor_fn: FN,
cost_fn: FC,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -34,6 +36,7 @@ where
usize::MAX,
cost_fn,
leaf_check_fn,
max_ops,
)
}

Expand Down Expand Up @@ -99,8 +102,9 @@ mod test {
};

let leaf_check_fn = |n: &Node| n.len() == total_items;
let max_ops = usize::MAX;

let (cost, best_node) = bfs(vec![], successor_fn, cost_fn, leaf_check_fn).unwrap();
let (cost, best_node) = bfs(vec![], successor_fn, cost_fn, leaf_check_fn, max_ops).unwrap();
let cost = u32::MAX - cost;

assert_eq!(cost, 6);
Expand Down
20 changes: 19 additions & 1 deletion src/bms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub struct BmsReachable<N, FN, FC, C: Ord> {
eval_fn: FC,
branch_factor: usize,
beam_width: usize,
remained_ops: usize,
pool: BinaryHeap<ScoredItem<C, N>>,
}

Expand All @@ -54,6 +55,10 @@ where
type Item = N;

fn next(&mut self) -> Option<Self::Item> {
if self.remained_ops == 0 {
return None;
}
self.remained_ops -= 1;
if self.to_see.is_empty() {
let max_iter = std::cmp::min(self.pool.len(), self.beam_width);
for _ in 0..max_iter {
Expand Down Expand Up @@ -88,6 +93,7 @@ pub fn bms_reach<N, FN, IN, FC, C>(
eval_fn: FC,
branch_factor: usize,
beam_width: usize,
max_ops: usize,
) -> BmsReachable<N, FN, FC, C>
where
N: Clone,
Expand All @@ -102,6 +108,7 @@ where
eval_fn,
branch_factor,
beam_width,
remained_ops: max_ops,
pool: BinaryHeap::new(),
}
}
Expand All @@ -115,6 +122,7 @@ where
/// - `beam_width` decides muximum number of nodes at each depth.
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn bms<N, IN, FN, FC1, FC2, C, FR>(
Expand All @@ -125,6 +133,7 @@ pub fn bms<N, IN, FN, FC1, FC2, C, FR>(
beam_width: usize,
cost_fn: FC2,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -135,7 +144,14 @@ where
C: Ord + Copy + Bounded,
FR: Fn(&N) -> bool,
{
let mut res = bms_reach(start, successor_fn, eval_fn, branch_factor, beam_width);
let mut res = bms_reach(
start,
successor_fn,
eval_fn,
branch_factor,
beam_width,
max_ops,
);
let mut best_leaf_node = None;
let mut current_best_cost = C::max_value();
loop {
Expand Down Expand Up @@ -319,6 +335,7 @@ mod test {

let branch_factor = 10;
let beam_width = 5;
let max_ops = usize::MAX;
let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start));
let leaf_check_fn = |n: &Node| n.is_leaf();

Expand All @@ -330,6 +347,7 @@ mod test {
beam_width,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();

Expand Down
6 changes: 5 additions & 1 deletion src/dfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ use crate::bbs::bbs;
/// - `successor_fn` returns a list of successors for a given node.
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn dfs<N, IN, FN, FC, C, FR>(
start: N,
successor_fn: FN,
cost_fn: FC,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -32,6 +34,7 @@ where
|_| Some(C::min_value()),
cost_fn,
leaf_check_fn,
max_ops,
)
}

Expand Down Expand Up @@ -97,8 +100,9 @@ mod test {
};

let leaf_check_fn = |n: &Node| n.len() == total_items;
let max_ops = usize::MAX;

let (cost, best_node) = dfs(vec![], successor_fn, cost_fn, leaf_check_fn).unwrap();
let (cost, best_node) = dfs(vec![], successor_fn, cost_fn, leaf_check_fn, max_ops).unwrap();
let cost = u32::MAX - cost;

assert_eq!(cost, 6);
Expand Down
16 changes: 14 additions & 2 deletions src/gds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::bms::bms;
/// - `eval_fn` returns the approximated cost of a given node to sort and select k-best
/// - `cost_fn` returns the final cost of a leaf node
/// - `leaf_check_fn` check if a node is leaf or not
/// - `max_ops` is the maximum number of search operations to perform
///
/// This function returns Some of a tuple of (cost, leaf node) if found, otherwise returns None
pub fn gds<N, IN, FN, FC1, FC2, C, FR>(
Expand All @@ -19,6 +20,7 @@ pub fn gds<N, IN, FN, FC1, FC2, C, FR>(
eval_fn: FC1,
cost_fn: FC2,
leaf_check_fn: FR,
max_ops: usize,
) -> Option<(C, N)>
where
N: Clone,
Expand All @@ -37,6 +39,7 @@ where
1,
cost_fn,
leaf_check_fn,
max_ops,
)
}

Expand Down Expand Up @@ -173,8 +176,17 @@ mod test {
let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start));
let leaf_check_fn = |n: &Node| n.is_leaf();

let (cost, best_node) =
gds(root_node, successor_fn, eval_fn, cost_fn, leaf_check_fn).unwrap();
let max_ops = usize::MAX;

let (cost, best_node) = gds(
root_node,
successor_fn,
eval_fn,
cost_fn,
leaf_check_fn,
max_ops,
)
.unwrap();

assert!(cost < 9000);
let mut visited_cities = best_node.parents.clone();
Expand Down