2020 fn conv2d_inplace ( & mut self , kernel : ArrayView3 < Self :: Data > ) -> Result < ( ) , Error > ;
2121}
2222
23+ fn kernel_centre ( rows : usize , cols : usize ) -> ( usize , usize ) {
24+ let row_offset = rows / 2 - ( ( rows % 2 == 0 ) as usize ) ;
25+ let col_offset = cols / 2 - ( ( cols % 2 == 0 ) as usize ) ;
26+ ( row_offset, col_offset)
27+ }
28+
2329impl < T > ConvolutionExt for Array3 < T >
2430where
2531 T : Copy + Clone + Num + NumAssignOps ,
3339 let k_s = kernel. shape ( ) ;
3440 // Bit icky but handles fact that uncentred convolutions will cross the bounds
3541 // otherwise
36- let row_offset = k_s[ 0 ] / 2 - ( ( k_s[ 0 ] % 2 == 0 ) as usize ) ;
37- let col_offset = k_s[ 1 ] / 2 - ( ( k_s[ 1 ] % 2 == 0 ) as usize ) ;
38-
42+ let ( row_offset, col_offset) = kernel_centre ( k_s[ 0 ] , k_s[ 1 ] ) ;
3943 // row_offset * 2 may not equal k_s[0] due to truncation
4044 let shape = (
4145 self . shape ( ) [ 0 ] - row_offset * 2 ,
6064
6165 fn conv2d_inplace ( & mut self , kernel : ArrayView3 < Self :: Data > ) -> Result < ( ) , Error > {
6266 let data = self . conv2d ( kernel) ?;
67+ let shape = kernel. shape ( ) ;
68+ let centre = kernel_centre ( shape[ 0 ] , shape[ 1 ] ) ;
6369 for ( d, v) in self . indexed_iter_mut ( ) {
64- if let Some ( d) = data. get ( d) {
70+ if d. 0 < centre. 0 || d. 1 < centre. 1 {
71+ continue ;
72+ }
73+ let centred = ( d. 0 - centre. 0 , d. 1 - centre. 1 , d. 2 ) ;
74+ if let Some ( d) = data. get ( centred) {
6575 * v = * d;
6676 }
6777 }
91101#[ cfg( test) ]
92102mod tests {
93103 use super :: * ;
94- use crate :: core:: colour_models:: RGB ;
104+ use crate :: core:: colour_models:: { Gray , RGB } ;
105+ use ndarray:: arr3;
95106
96107 #[ test]
97108 fn bad_dimensions ( ) {
@@ -111,4 +122,51 @@ mod tests {
111122 assert ! ( i. conv2d( good_kern. view( ) ) . is_ok( ) ) ;
112123 assert ! ( i. conv2d_inplace( good_kern. view( ) ) . is_ok( ) ) ;
113124 }
125+
126+ #[ test]
127+ fn basic_conv ( ) {
128+ let input_pixels = vec ! [ 1 , 1 , 1 , 0 , 0 ,
129+ 0 , 1 , 1 , 1 , 0 ,
130+ 0 , 0 , 1 , 1 , 1 ,
131+ 0 , 0 , 1 , 1 , 0 ,
132+ 0 , 1 , 1 , 0 , 0 ] ;
133+ let output_pixels = vec ! [ 4 , 3 , 4 ,
134+ 2 , 4 , 3 ,
135+ 2 , 3 , 4 ] ;
136+
137+ let kern = arr3 ( & [ [ [ 1 ] , [ 0 ] , [ 1 ] ] ,
138+ [ [ 0 ] , [ 1 ] , [ 0 ] ] ,
139+ [ [ 1 ] , [ 0 ] , [ 1 ] ] ] ) ;
140+
141+ let input = Image :: < u8 , Gray > :: from_shape_data ( 5 , 5 , input_pixels) ;
142+ let expected = Image :: < u8 , Gray > :: from_shape_data ( 3 , 3 , output_pixels) ;
143+
144+ assert_eq ! ( Ok ( expected) , input. conv2d( kern. view( ) ) ) ;
145+ }
146+
147+ #[ test]
148+ fn basic_conv_inplace ( ) {
149+ let input_pixels = vec ! [ 1 , 1 , 1 , 0 , 0 ,
150+ 0 , 1 , 1 , 1 , 0 ,
151+ 0 , 0 , 1 , 1 , 1 ,
152+ 0 , 0 , 1 , 1 , 0 ,
153+ 0 , 1 , 1 , 0 , 0 ] ;
154+
155+ let output_pixels = vec ! [ 1 , 1 , 1 , 0 , 0 ,
156+ 0 , 4 , 3 , 4 , 0 ,
157+ 0 , 2 , 4 , 3 , 1 ,
158+ 0 , 2 , 3 , 4 , 0 ,
159+ 0 , 1 , 1 , 0 , 0 ] ;
160+
161+ let kern = arr3 ( & [ [ [ 1 ] , [ 0 ] , [ 1 ] ] ,
162+ [ [ 0 ] , [ 1 ] , [ 0 ] ] ,
163+ [ [ 1 ] , [ 0 ] , [ 1 ] ] ] ) ;
164+
165+ let mut input = Image :: < u8 , Gray > :: from_shape_data ( 5 , 5 , input_pixels) ;
166+ let expected = Image :: < u8 , Gray > :: from_shape_data ( 5 , 5 , output_pixels) ;
167+
168+ input. conv2d_inplace ( kern. view ( ) ) . unwrap ( ) ;
169+
170+ assert_eq ! ( expected, input) ;
171+ }
114172}
0 commit comments