Skip to content

Commit 2ca801c

Browse files
committed
ndarray-gen: Add simple internal interface for building matrices
1 parent 9f1b35d commit 2ca801c

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ members = [
8585
default-members = [
8686
".",
8787
"ndarray-rand",
88+
"crates/ndarray-gen",
8889
"crates/numeric-tests",
8990
"crates/serialization-tests",
9091
# exclude blas-tests that depends on BLAS install
@@ -93,6 +94,7 @@ default-members = [
9394
[workspace.dependencies]
9495
ndarray = { version = "0.16", path = "." }
9596
ndarray-rand = { path = "ndarray-rand" }
97+
ndarray-gen = { path = "crates/ndarray-gen" }
9698

9799
num-integer = { version = "0.1.39", default-features = false }
98100
num-traits = { version = "0.2", default-features = false }

crates/ndarray-gen/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "ndarray-gen"
3+
version = "0.1.0"
4+
edition = "2018"
5+
publish = false
6+
7+
[dependencies]
8+
ndarray = { workspace = true }
9+
num-traits = { workspace = true }

crates/ndarray-gen/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
## ndarray-gen
3+
4+
Array generation functions, used for testing.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2024 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use ndarray::Array;
10+
use ndarray::Dimension;
11+
use ndarray::IntoDimension;
12+
use ndarray::Order;
13+
14+
use num_traits::Num;
15+
16+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
17+
pub struct ArrayBuilder<D: Dimension>
18+
{
19+
dim: D,
20+
memory_order: Order,
21+
generator: ElementGenerator,
22+
}
23+
24+
/// How to generate elements
25+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
26+
pub enum ElementGenerator
27+
{
28+
Sequential,
29+
Zero,
30+
}
31+
32+
impl<D: Dimension> Default for ArrayBuilder<D>
33+
{
34+
fn default() -> Self
35+
{
36+
Self::new(D::zeros(D::NDIM.unwrap_or(1)))
37+
}
38+
}
39+
40+
impl<D> ArrayBuilder<D>
41+
where D: Dimension
42+
{
43+
pub fn new(dim: impl IntoDimension<Dim = D>) -> Self
44+
{
45+
ArrayBuilder {
46+
dim: dim.into_dimension(),
47+
memory_order: Order::C,
48+
generator: ElementGenerator::Sequential,
49+
}
50+
}
51+
52+
pub fn memory_order(mut self, order: Order) -> Self
53+
{
54+
self.memory_order = order;
55+
self
56+
}
57+
58+
pub fn generator(mut self, generator: ElementGenerator) -> Self
59+
{
60+
self.generator = generator;
61+
self
62+
}
63+
64+
pub fn build<T>(self) -> Array<T, D>
65+
where T: Num + Clone
66+
{
67+
let mut current = T::zero();
68+
let size = self.dim.size();
69+
let use_zeros = self.generator == ElementGenerator::Zero;
70+
Array::from_iter((0..size).map(|_| {
71+
let ret = current.clone();
72+
if !use_zeros {
73+
current = ret.clone() + T::one();
74+
}
75+
ret
76+
}))
77+
.into_shape_with_order((self.dim, self.memory_order))
78+
.unwrap()
79+
}
80+
}
81+
82+
#[test]
83+
fn test_order()
84+
{
85+
let (m, n) = (12, 13);
86+
let c = ArrayBuilder::new((m, n))
87+
.memory_order(Order::C)
88+
.build::<i32>();
89+
let f = ArrayBuilder::new((m, n))
90+
.memory_order(Order::F)
91+
.build::<i32>();
92+
93+
assert_eq!(c.shape(), &[m, n]);
94+
assert_eq!(f.shape(), &[m, n]);
95+
assert_eq!(c.strides(), &[n as isize, 1]);
96+
assert_eq!(f.strides(), &[1, m as isize]);
97+
}

crates/ndarray-gen/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright 2024 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
/// Build ndarray arrays for test purposes
10+
11+
pub mod array_builder;

0 commit comments

Comments
 (0)