Skip to content

Commit e1f11f7

Browse files
author
maxtremblay
committed
Implement kronecker product
1 parent 8aa01f1 commit e1f11f7

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

src/matrix/kronecker.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use super::SparseBinMat;
2+
use crate::SparseBinSlice;
3+
4+
pub(super) fn kronecker_product(
5+
left_matrix: &SparseBinMat,
6+
right_matrix: &SparseBinMat,
7+
) -> SparseBinMat {
8+
let rows = left_matrix
9+
.rows()
10+
.flat_map(|row| kron_row(row, right_matrix))
11+
.collect();
12+
let number_of_columns = left_matrix.number_of_columns * right_matrix.number_of_columns();
13+
SparseBinMat::new(number_of_columns, rows)
14+
}
15+
16+
fn kron_row<'a>(
17+
left_row: SparseBinSlice<'a>,
18+
right_matrix: &'a SparseBinMat,
19+
) -> impl Iterator<Item = Vec<usize>> + 'a {
20+
right_matrix.rows().map(move |right_row| {
21+
left_row
22+
.non_trivial_positions()
23+
.flat_map(|position| pad_row(position * right_row.len(), &right_row))
24+
.collect()
25+
})
26+
}
27+
28+
fn pad_row<'a>(pad: usize, row: &'a SparseBinSlice<'a>) -> impl Iterator<Item = usize> + 'a {
29+
row.non_trivial_positions()
30+
.map(move |position| position + pad)
31+
}
32+
33+
#[cfg(test)]
34+
mod test {
35+
use super::*;
36+
37+
#[test]
38+
fn left_kron_with_identity() {
39+
let matrix = SparseBinMat::new(4, vec![vec![0, 2], vec![1, 3]]);
40+
let product = matrix.kron_with(&SparseBinMat::identity(2));
41+
let expected = SparseBinMat::new(8, vec![vec![0, 4], vec![1, 5], vec![2, 6], vec![3, 7]]);
42+
assert_eq!(product, expected);
43+
}
44+
45+
#[test]
46+
fn right_kron_with_identity() {
47+
let matrix = SparseBinMat::new(4, vec![vec![0, 2], vec![1, 3]]);
48+
let product = SparseBinMat::identity(2).kron_with(&matrix);
49+
let expected = SparseBinMat::new(8, vec![vec![0, 2], vec![1, 3], vec![4, 6], vec![5, 7]]);
50+
assert_eq!(product, expected);
51+
}
52+
53+
#[test]
54+
fn kron_with_itself() {
55+
let matrix = SparseBinMat::new(4, vec![vec![0, 2], vec![1, 3]]);
56+
let product = matrix.kron_with(&matrix);
57+
let expected = SparseBinMat::new(
58+
16,
59+
vec![
60+
vec![0, 2, 8, 10],
61+
vec![1, 3, 9, 11],
62+
vec![4, 6, 12, 14],
63+
vec![5, 7, 13, 15],
64+
],
65+
);
66+
assert_eq!(product, expected);
67+
}
68+
}

src/matrix/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ use constructor_utils::initialize_from;
1717
mod gauss_jordan;
1818
use gauss_jordan::GaussJordan;
1919

20+
mod kronecker;
21+
use kronecker::kronecker_product;
22+
2023
mod non_trivial_elements;
2124
pub use self::non_trivial_elements::NonTrivialElements;
2225

@@ -733,6 +736,24 @@ impl SparseBinMat {
733736
self.keep_only_columns(&to_keep)
734737
}
735738

739+
/// Returns the Kronecker product of two matrices.
740+
///
741+
/// # Example
742+
///
743+
/// ```
744+
/// # use sparse_bin_mat::SparseBinMat;
745+
/// let left_matrix = SparseBinMat::new(2, vec![vec![1], vec![0]]);
746+
/// let right_matrix = SparseBinMat::new(3, vec![vec![0, 1], vec![1, 2]]);
747+
///
748+
/// let product = left_matrix.kron_with(&right_matrix);
749+
/// let expected = SparseBinMat::new(6, vec![vec![3, 4], vec![4, 5], vec![0, 1], vec![1, 2]]);
750+
///
751+
/// assert_eq!(product, expected);
752+
/// ```
753+
pub fn kron_with(&self, other: &Self) -> Self {
754+
kronecker_product(self, other)
755+
}
756+
736757
/// Returns a json string for the matrix.
737758
pub fn as_json(&self) -> Result<String, serde_json::Error> {
738759
serde_json::to_string(self)

0 commit comments

Comments
 (0)