Skip to content

Commit dfa41d8

Browse files
committed
Reorganize code
1 parent 6dcd10c commit dfa41d8

File tree

5 files changed

+120
-118
lines changed

5 files changed

+120
-118
lines changed

src/graph/mod.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use std::collections::HashMap;
2+
use std::fmt::Debug;
3+
use std::hash::Hash;
4+
5+
#[derive(Clone, Debug)]
6+
pub struct Node {
7+
pub supply: f64,
8+
}
9+
10+
#[derive(Clone, Debug)]
11+
pub struct Edge {
12+
pub start: usize,
13+
pub end: usize,
14+
pub capacity: f64,
15+
pub cost: f64,
16+
}
17+
18+
#[derive(Clone, Debug)]
19+
pub struct Graph {
20+
pub nodes: Vec<Node>,
21+
pub edges: Vec<Edge>,
22+
}
23+
24+
#[derive(Clone, Debug)]
25+
pub struct GraphBuilder<T: Eq + Hash + Debug> {
26+
pub nodes: Vec<Node>,
27+
pub node_label_to_index: HashMap<T, usize>,
28+
pub edges: Vec<Edge>,
29+
}
30+
31+
impl<T: Eq + Hash + Debug> GraphBuilder<T> {
32+
pub fn new() -> GraphBuilder<T> {
33+
GraphBuilder {
34+
nodes: vec![],
35+
node_label_to_index: HashMap::new(),
36+
edges: vec![],
37+
}
38+
}
39+
40+
pub fn add_node(&mut self, label: T, supply: f64) {
41+
self.nodes.push(Node {
42+
supply
43+
});
44+
self.node_label_to_index.insert(label, self.nodes.len() - 1);
45+
}
46+
47+
pub fn add_edge(&mut self, label_u: T, label_v: T, capacity: f64, cost: f64) {
48+
let u = self.get_node_or_create(label_u);
49+
let v = self.get_node_or_create(label_v);
50+
self.edges.push(Edge {
51+
start: u,
52+
end: v,
53+
capacity,
54+
cost,
55+
});
56+
}
57+
58+
pub fn get_node_or_create(&mut self, label: T) -> usize {
59+
match self.node_label_to_index.get(&label) {
60+
Some(u) => {
61+
*u
62+
}
63+
None => {
64+
self.nodes.push(Node {
65+
supply: 0.0
66+
});
67+
self.node_label_to_index.insert(label, self.nodes.len() - 1);
68+
self.nodes.len() - 1
69+
}
70+
}
71+
}
72+
73+
pub fn build(&self) -> Graph {
74+
Graph {
75+
nodes: self.nodes.clone(),
76+
edges: self.edges.clone(),
77+
}
78+
}
79+
}
80+
81+
impl Graph {
82+
pub fn new() -> Graph {
83+
Graph {
84+
nodes: vec![],
85+
edges: vec![],
86+
}
87+
}
88+
89+
pub fn add_node(&mut self, supply: f64) {
90+
self.nodes.push(Node {
91+
supply
92+
});
93+
}
94+
95+
pub fn add_edge(&mut self, u: usize, v: usize, capacity: f64, cost: f64) {
96+
self.edges.push(Edge {
97+
start: u,
98+
end: v,
99+
capacity,
100+
cost,
101+
});
102+
}
103+
}

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
pub mod lp;
1+
pub mod ns;
2+
pub mod graph;

src/ns/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pub mod network_simplex;
2+
pub mod optimal_transport;
3+

src/lp/network_simplex.rs renamed to src/ns/network_simplex.rs

Lines changed: 5 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,11 @@
11
/*
22
Network simplex algorithm, adapted from networkx library implementation (https://networkx.org/documentation/stable/_modules/networkx/algorithms/flow/networksimplex.html)
33
*/
4-
use std::collections::HashMap;
54
use std::fmt::Debug;
65
use std::hash::Hash;
76
use std::iter::repeat;
87
use std::usize;
9-
10-
#[derive(Clone, Debug)]
11-
struct Node {
12-
supply: f64,
13-
}
14-
15-
#[derive(Clone, Debug)]
16-
struct Edge {
17-
start: usize,
18-
end: usize,
19-
capacity: f64,
20-
cost: f64,
21-
}
22-
23-
#[derive(Clone, Debug)]
24-
pub struct Graph {
25-
nodes: Vec<Node>,
26-
edges: Vec<Edge>,
27-
}
28-
29-
#[derive(Clone, Debug)]
30-
pub struct GraphBuilder<T: Eq + Hash + Debug> {
31-
nodes: Vec<Node>,
32-
node_label_to_index: HashMap<T, usize>,
33-
edges: Vec<Edge>,
34-
}
35-
36-
impl<T: Eq + Hash + Debug> GraphBuilder<T> {
37-
pub fn new() -> GraphBuilder<T> {
38-
GraphBuilder {
39-
nodes: vec![],
40-
node_label_to_index: HashMap::new(),
41-
edges: vec![],
42-
}
43-
}
44-
45-
pub fn add_node(&mut self, label: T, supply: f64) {
46-
self.nodes.push(Node {
47-
supply
48-
});
49-
self.node_label_to_index.insert(label, self.nodes.len() - 1);
50-
}
51-
52-
pub fn add_edge(&mut self, label_u: T, label_v: T, capacity: f64, cost: f64) {
53-
let u = self.get_node_or_create(label_u);
54-
let v = self.get_node_or_create(label_v);
55-
self.edges.push(Edge {
56-
start: u,
57-
end: v,
58-
capacity,
59-
cost,
60-
});
61-
}
62-
63-
fn get_node_or_create(&mut self, label: T) -> usize {
64-
match self.node_label_to_index.get(&label) {
65-
Some(u) => {
66-
*u
67-
}
68-
None => {
69-
self.nodes.push(Node {
70-
supply: 0.0
71-
});
72-
self.node_label_to_index.insert(label, self.nodes.len() - 1);
73-
self.nodes.len() - 1
74-
}
75-
}
76-
}
77-
78-
fn build(&self) -> Graph {
79-
Graph {
80-
nodes: self.nodes.clone(),
81-
edges: self.edges.clone(),
82-
}
83-
}
84-
}
85-
86-
impl Graph {
87-
pub fn new() -> Graph {
88-
Graph {
89-
nodes: vec![],
90-
edges: vec![],
91-
}
92-
}
93-
94-
pub fn add_node(&mut self, supply: f64) {
95-
self.nodes.push(Node {
96-
supply
97-
});
98-
}
99-
100-
pub fn add_edge(&mut self, u: usize, v: usize, capacity: f64, cost: f64) {
101-
self.edges.push(Edge {
102-
start: u,
103-
end: v,
104-
capacity,
105-
cost,
106-
});
107-
}
108-
}
8+
use crate::graph::{Edge, Graph};
1099

11010
#[derive(Debug)]
11111
struct Solution<'a> {
@@ -474,7 +374,7 @@ fn argmin<S: Copy>(edges: impl Iterator<Item=S>, func: impl Fn(S) -> f64) -> Opt
474374
argmin
475375
}
476376

477-
pub fn network_simplex(graph: &Graph, eps: f64) -> Vec<f64> {
377+
pub fn solve_min_cost_flow(graph: &Graph, eps: f64) -> Vec<f64> {
478378
let mut graph = graph.clone();
479379
let mut solution = Solution::new(&mut graph, eps);
480380

@@ -512,10 +412,7 @@ pub fn network_simplex(graph: &Graph, eps: f64) -> Vec<f64> {
512412

513413
#[cfg(test)]
514414
mod tests {
515-
use ndarray_rand::rand;
516-
use ndarray_rand::rand::random;
517-
518-
use crate::lp::network_simplex::{Graph, GraphBuilder, network_simplex};
415+
use crate::graph::GraphBuilder;
519416

520417
#[test]
521418
fn test_ns() {
@@ -526,7 +423,7 @@ mod tests {
526423
graph.add_edge(String::from("a"), String::from("c"), 10.0, 6.0);
527424
graph.add_edge(String::from("b"), String::from("d"), 9.0, 1.0);
528425
graph.add_edge(String::from("c"), String::from("d"), 5.0, 2.0);
529-
let flow = super::network_simplex(&graph.build(), 10e-12);
426+
let flow = super::solve_min_cost_flow(&graph.build(), 10e-12);
530427
assert_eq!(vec![4.0, 1.0, 4.0, 1.0], flow);
531428
}
532429

@@ -546,7 +443,7 @@ mod tests {
546443
graph.add_edge(String::from("a"), String::from("t"), 4.0, 2.0);
547444
graph.add_edge(String::from("d"), String::from("w"), 4.0, 3.0);
548445
graph.add_edge(String::from("t"), String::from("w"), 1.0, 4.0);
549-
let flow = super::network_simplex(&graph.build(), 10e-12);
446+
let flow = super::solve_min_cost_flow(&graph.build(), 10e-12);
550447
assert_eq!(vec![2.0, 2.0, 1.0, 1.0, 4.0, 2.0, 1.0], flow);
551448
}
552449
}

src/lp/mod.rs renamed to src/ns/optimal_transport.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
mod network_simplex;
2-
31
use ndarray::prelude::*;
4-
use crate::lp::network_simplex::{Graph, network_simplex};
2+
use crate::graph::Graph;
3+
use crate::ns::network_simplex::solve_min_cost_flow;
54

6-
pub fn solve(u: &Array1<f64>, v: &Array1<f64>, cost_matrix: &Array2<f64>, eps: f64) -> Array2<f64> {
5+
pub fn solve_ot(u: &Array1<f64>, v: &Array1<f64>, cost_matrix: &Array2<f64>, eps: f64) -> Array2<f64> {
76
let mut graph: Graph = Graph::new();
87
let m = u.len();
98
let n = v.len();
@@ -21,7 +20,7 @@ pub fn solve(u: &Array1<f64>, v: &Array1<f64>, cost_matrix: &Array2<f64>, eps: f
2120
}
2221
}
2322

24-
let flow = Array1::from_vec(network_simplex(&graph, eps));
23+
let flow = Array1::from_vec(solve_min_cost_flow(&graph, eps));
2524

2625
let mut result: Array2<f64> = Array::zeros((m, n));
2726

@@ -41,7 +40,7 @@ mod tests {
4140
use ndarray_rand::rand::SeedableRng;
4241
use ndarray_rand::rand_distr::Normal;
4342
use rand_isaac::isaac64::Isaac64Rng;
44-
use crate::lp::solve;
43+
use crate::ns::optimal_transport::solve_ot;
4544

4645
#[test]
4746
fn test_solve() {
@@ -55,12 +54,11 @@ mod tests {
5554

5655
let u: Array1<f64> = Array1::ones(n) / (n as f64);
5756

58-
let result = solve(&u, &u, &cost_matrix, 1e-8);
57+
let result = solve_ot(&u, &u, &cost_matrix, 1e-8);
5958

6059

6160

6261
assert_abs_diff_eq!(result.sum_axis(Axis(0)), u);
6362
assert_abs_diff_eq!(result.sum_axis(Axis(1)), u);
6463
}
65-
}
66-
64+
}

0 commit comments

Comments
 (0)