Skip to content

Commit 79aba53

Browse files
committed
Merge branch 'develop'
Release 0.1.1
2 parents 65b4a1e + 1189f3c commit 79aba53

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "ndarray-vision"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
authors = ["xd009642 <danielmckenna93@gmail.com>"]
55
description = "A computer vision library built on top of ndarray"
66
repository = "https://github.com/xd009642/ndarray-vision"

src/processing/conv.rs

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ where
2020
fn conv2d_inplace(&mut self, kernel: ArrayView3<Self::Data>) -> Result<(), Error>;
2121
}
2222

23+
fn kernel_centre(rows: usize, cols: usize) -> (usize, usize) {
24+
let row_offset = rows / 2 - ((rows % 2 == 0) as usize);
25+
let col_offset = cols / 2 - ((cols % 2 == 0) as usize);
26+
(row_offset, col_offset)
27+
}
28+
2329
impl<T> ConvolutionExt for Array3<T>
2430
where
2531
T: Copy + Clone + Num + NumAssignOps,
@@ -33,9 +39,7 @@ where
3339
let k_s = kernel.shape();
3440
// Bit icky but handles fact that uncentred convolutions will cross the bounds
3541
// otherwise
36-
let row_offset = k_s[0] / 2 - ((k_s[0] % 2 == 0) as usize);
37-
let col_offset = k_s[1] / 2 - ((k_s[1] % 2 == 0) as usize);
38-
42+
let (row_offset, col_offset) = kernel_centre(k_s[0], k_s[1]);
3943
// row_offset * 2 may not equal k_s[0] due to truncation
4044
let shape = (
4145
self.shape()[0] - row_offset * 2,
@@ -60,8 +64,14 @@ where
6064

6165
fn conv2d_inplace(&mut self, kernel: ArrayView3<Self::Data>) -> Result<(), Error> {
6266
let data = self.conv2d(kernel)?;
67+
let shape = kernel.shape();
68+
let centre = kernel_centre(shape[0], shape[1]);
6369
for (d, v) in self.indexed_iter_mut() {
64-
if let Some(d) = data.get(d) {
70+
if d.0 < centre.0 || d.1 < centre.1 {
71+
continue;
72+
}
73+
let centred = (d.0 - centre.0, d.1 - centre.1, d.2);
74+
if let Some(d) = data.get(centred) {
6575
*v = *d;
6676
}
6777
}
@@ -91,7 +101,8 @@ where
91101
#[cfg(test)]
92102
mod tests {
93103
use super::*;
94-
use crate::core::colour_models::RGB;
104+
use crate::core::colour_models::{Gray, RGB};
105+
use ndarray::arr3;
95106

96107
#[test]
97108
fn bad_dimensions() {
@@ -111,4 +122,51 @@ mod tests {
111122
assert!(i.conv2d(good_kern.view()).is_ok());
112123
assert!(i.conv2d_inplace(good_kern.view()).is_ok());
113124
}
125+
126+
#[test]
127+
fn basic_conv() {
128+
let input_pixels = vec![1, 1, 1, 0, 0,
129+
0, 1, 1, 1, 0,
130+
0, 0, 1, 1, 1,
131+
0, 0, 1, 1, 0,
132+
0, 1, 1, 0, 0];
133+
let output_pixels = vec![4, 3, 4,
134+
2, 4, 3,
135+
2, 3, 4];
136+
137+
let kern = arr3(&[[[1], [0], [1]],
138+
[[0], [1], [0]],
139+
[[1], [0], [1]]]);
140+
141+
let input = Image::<u8, Gray>::from_shape_data(5, 5, input_pixels);
142+
let expected = Image::<u8, Gray>::from_shape_data(3, 3, output_pixels);
143+
144+
assert_eq!(Ok(expected), input.conv2d(kern.view()));
145+
}
146+
147+
#[test]
148+
fn basic_conv_inplace() {
149+
let input_pixels = vec![1, 1, 1, 0, 0,
150+
0, 1, 1, 1, 0,
151+
0, 0, 1, 1, 1,
152+
0, 0, 1, 1, 0,
153+
0, 1, 1, 0, 0];
154+
155+
let output_pixels = vec![1, 1, 1, 0, 0,
156+
0, 4, 3, 4, 0,
157+
0, 2, 4, 3, 1,
158+
0, 2, 3, 4, 0,
159+
0, 1, 1, 0, 0];
160+
161+
let kern = arr3(&[[[1], [0], [1]],
162+
[[0], [1], [0]],
163+
[[1], [0], [1]]]);
164+
165+
let mut input = Image::<u8, Gray>::from_shape_data(5, 5, input_pixels);
166+
let expected = Image::<u8, Gray>::from_shape_data(5, 5, output_pixels);
167+
168+
input.conv2d_inplace(kern.view()).unwrap();
169+
170+
assert_eq!(expected, input);
171+
}
114172
}

0 commit comments

Comments
 (0)