|
| 1 | +//! # One-hot Encoding For [RealNumber](../../math/num/trait.RealNumber.html) Matrices |
| 2 | +//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by replacing all categorical variables with their one-hot equivalents |
| 3 | +//! |
| 4 | +//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [CategoryMapper](../series_encoder/struct.CategoryMapper.html) |
| 5 | +//! |
| 6 | +//! ### Usage Example |
| 7 | +//! ``` |
| 8 | +//! use smartcore::linalg::naive::dense_matrix::DenseMatrix; |
| 9 | +//! use smartcore::preprocessing::categorical::{OneHotEncoder, OneHotEncoderParams}; |
| 10 | +//! let data = DenseMatrix::from_2d_array(&[ |
| 11 | +//! &[1.5, 1.0, 1.5, 3.0], |
| 12 | +//! &[1.5, 2.0, 1.5, 4.0], |
| 13 | +//! &[1.5, 1.0, 1.5, 5.0], |
| 14 | +//! &[1.5, 2.0, 1.5, 6.0], |
| 15 | +//! ]); |
| 16 | +//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]); |
| 17 | +//! // Infer number of categories from data and return a reusable encoder |
| 18 | +//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap(); |
| 19 | +//! // Transform categorical columns to one-hot encoding (the fitted encoder can be reused on similar data) |
| 20 | +//! let oh_data = encoder.transform(&data).unwrap(); |
| 21 | +//! // Produces the following: |
| 22 | +//! // &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0] |
| 23 | +//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0] |
| 24 | +//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0] |
| 25 | +//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0] |
| 26 | +//! ``` |
| 27 | +use std::iter; |
| 28 | + |
| 29 | +use crate::error::Failed; |
| 30 | +use crate::linalg::Matrix; |
| 31 | + |
| 32 | +use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable}; |
| 33 | +use crate::preprocessing::series_encoder::CategoryMapper; |
| 34 | + |
/// Settings that tell the one-hot encoder which matrix columns are categorical
#[derive(Debug, Clone)]
pub struct OneHotEncoderParams {
    /// Indices of the matrix columns that hold categorical variables
    pub col_idx_categorical: Option<Vec<usize>>,
    /// (Currently not implemented) Try and infer which of the matrix columns are categorical variables
    infer_categorical: bool,
}

impl OneHotEncoderParams {
    /// Build parameters from an explicit list of categorical column indices.
    ///
    /// The indices are copied as given; inference of categorical columns is left disabled.
    pub fn from_cat_idx(categorical_params: &[usize]) -> Self {
        let col_idx_categorical = Some(Vec::from(categorical_params));
        OneHotEncoderParams {
            col_idx_categorical,
            infer_categorical: false,
        }
    }
}
| 53 | + |
/// Map every original column index to its index in the one-hot expanded matrix.
///
/// A categorical column with `k` categories occupies `k` consecutive columns in the
/// expanded matrix, so every column to its right is shifted by `k - 1`. Entry `i` of the
/// returned vector is the new position of original column `i`; for a categorical column
/// it is the position of the first column of its one-hot block.
///
/// * `num_params` - number of columns in the original matrix
/// * `cat_sizes` - number of categories of each categorical column, in the same order as `cat_idxs`
/// * `cat_idxs` - indices of the categorical columns, sorted in strictly increasing order
///
/// Entries of `cat_idxs` that are `>= num_params` are ignored. A categorical column with
/// zero categories takes no space, so the following column receives the same index.
fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) -> Vec<usize> {
    // Pair each categorical column with its width; both slices are walked in lockstep.
    let mut categorical = cat_idxs
        .iter()
        .copied()
        .zip(cat_sizes.iter().copied())
        .peekable();

    let mut new_param_idxs = Vec::with_capacity(num_params);
    // First column of the expanded matrix that has not been assigned yet.
    let mut next_free = 0;

    for idx in 0..num_params {
        new_param_idxs.push(next_free);

        // A categorical column consumes one output column per category,
        // every other column consumes exactly one.
        let width = match categorical.peek() {
            Some(&(cat_idx, size)) if cat_idx == idx => {
                categorical.next();
                size
            }
            _ => 1,
        };
        next_free += width;
    }

    new_param_idxs
}
| 88 | + |
| 89 | +fn validate_col_is_categorical<T: Categorizable>(data: &[T]) -> bool { |
| 90 | + for v in data { |
| 91 | + if !v.is_valid() { |
| 92 | + return false; |
| 93 | + } |
| 94 | + } |
| 95 | + true |
| 96 | +} |
| 97 | + |
| 98 | +/// Encode Categorical variavbles of data matrix to one-hot |
| 99 | +#[derive(Debug, Clone)] |
| 100 | +pub struct OneHotEncoder { |
| 101 | + category_mappers: Vec<CategoryMapper<CategoricalFloat>>, |
| 102 | + col_idx_categorical: Vec<usize>, |
| 103 | +} |
| 104 | + |
| 105 | +impl OneHotEncoder { |
| 106 | + /// Create an encoder instance with categories infered from data matrix |
| 107 | + pub fn fit<T, M>(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed> |
| 108 | + where |
| 109 | + T: Categorizable, |
| 110 | + M: Matrix<T>, |
| 111 | + { |
| 112 | + match (params.col_idx_categorical, params.infer_categorical) { |
| 113 | + (None, false) => Err(Failed::fit( |
| 114 | + "Must pass categorical series ids or infer flag", |
| 115 | + )), |
| 116 | + |
| 117 | + (Some(_idxs), true) => Err(Failed::fit( |
| 118 | + "Ambigous parameters, got both infer and categroy ids", |
| 119 | + )), |
| 120 | + |
| 121 | + (Some(mut idxs), false) => { |
| 122 | + // make sure categories have same order as data columns |
| 123 | + idxs.sort_unstable(); |
| 124 | + |
| 125 | + let (nrows, _) = data.shape(); |
| 126 | + |
| 127 | + // col buffer to avoid allocations |
| 128 | + let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect(); |
| 129 | + |
| 130 | + let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len()); |
| 131 | + |
| 132 | + for &idx in &idxs { |
| 133 | + data.copy_col_as_vec(idx, &mut col_buf); |
| 134 | + if !validate_col_is_categorical(&col_buf) { |
| 135 | + let msg = format!( |
| 136 | + "Column {} of data matrix containts non categorizable (integer) values", |
| 137 | + idx |
| 138 | + ); |
| 139 | + return Err(Failed::fit(&msg[..])); |
| 140 | + } |
| 141 | + let hashable_col = col_buf.iter().map(|v| v.to_category()); |
| 142 | + res.push(CategoryMapper::fit_to_iter(hashable_col)); |
| 143 | + } |
| 144 | + |
| 145 | + Ok(Self { |
| 146 | + category_mappers: res, |
| 147 | + col_idx_categorical: idxs, |
| 148 | + }) |
| 149 | + } |
| 150 | + |
| 151 | + (None, true) => { |
| 152 | + todo!("Auto-Inference for Categorical Variables not yet implemented") |
| 153 | + } |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + /// Transform categorical variables to one-hot encoded and return a new matrix |
| 158 | + pub fn transform<T, M>(&self, x: &M) -> Result<M, Failed> |
| 159 | + where |
| 160 | + T: Categorizable, |
| 161 | + M: Matrix<T>, |
| 162 | + { |
| 163 | + let (nrows, p) = x.shape(); |
| 164 | + let additional_params: Vec<usize> = self |
| 165 | + .category_mappers |
| 166 | + .iter() |
| 167 | + .map(|enc| enc.num_categories()) |
| 168 | + .collect(); |
| 169 | + |
| 170 | + // Eac category of size v adds v-1 params |
| 171 | + let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1); |
| 172 | + |
| 173 | + let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]); |
| 174 | + let mut res = M::zeros(nrows, expandws_p); |
| 175 | + |
| 176 | + for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() { |
| 177 | + let cidx = new_col_idx[old_cidx]; |
| 178 | + let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category()); |
| 179 | + let sencoder = &self.category_mappers[pidx]; |
| 180 | + let oh_series = col_iter.map(|c| sencoder.get_one_hot::<T, Vec<T>>(&c)); |
| 181 | + |
| 182 | + for (row, oh_vec) in oh_series.enumerate() { |
| 183 | + match oh_vec { |
| 184 | + None => { |
| 185 | + // Since we support T types, bad value in a series causes in to be invalid |
| 186 | + let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx); |
| 187 | + return Err(Failed::transform(&msg[..])); |
| 188 | + } |
| 189 | + Some(v) => { |
| 190 | + // copy one hot vectors to their place in the data matrix; |
| 191 | + for (col_ofst, &val) in v.iter().enumerate() { |
| 192 | + res.set(row, cidx + col_ofst, val); |
| 193 | + } |
| 194 | + } |
| 195 | + } |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + // copy old data in x to their new location while skipping catergorical vars (already treated) |
| 200 | + let mut skip_idx_iter = self.col_idx_categorical.iter(); |
| 201 | + let mut cur_skip = skip_idx_iter.next(); |
| 202 | + |
| 203 | + for (old_p, &new_p) in new_col_idx.iter().enumerate() { |
| 204 | + // if found treated varible, skip it |
| 205 | + if let Some(&v) = cur_skip { |
| 206 | + if v == old_p { |
| 207 | + cur_skip = skip_idx_iter.next(); |
| 208 | + continue; |
| 209 | + } |
| 210 | + } |
| 211 | + |
| 212 | + for r in 0..nrows { |
| 213 | + let val = x.get(r, old_p); |
| 214 | + res.set(r, new_p, val); |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + Ok(res) |
| 219 | + } |
| 220 | +} |
| 221 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;
    use crate::preprocessing::series_encoder::CategoryMapper;

    /// Fit an encoder on `x` for the given categorical columns and check the transform output.
    fn assert_encodes_to(x: DenseMatrix<f64>, expected_x: DenseMatrix<f64>, cat_idx: &[usize]) {
        let params = OneHotEncoderParams::from_cat_idx(cat_idx);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        let nm = oh_enc.transform(&x).unwrap();
        assert_eq!(nm, expected_x);
    }

    /// Fixture whose categorical variables sit in the first and the last column.
    fn build_cat_first_and_last() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
        let orig = DenseMatrix::from_2d_array(&[
            &[1.0, 1.5, 3.0],
            &[2.0, 1.5, 4.0],
            &[1.0, 1.5, 5.0],
            &[2.0, 1.5, 6.0],
        ]);

        let oh_enc = DenseMatrix::from_2d_array(&[
            &[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
            &[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
            &[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
            &[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
        ]);

        (orig, oh_enc)
    }

    /// Fixture whose categorical variables sit in columns 1 and 3.
    fn build_fake_matrix() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
        let orig = DenseMatrix::from_2d_array(&[
            &[1.5, 1.0, 1.5, 3.0],
            &[1.5, 2.0, 1.5, 4.0],
            &[1.5, 1.0, 1.5, 5.0],
            &[1.5, 2.0, 1.5, 6.0],
        ]);

        let oh_enc = DenseMatrix::from_2d_array(&[
            &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
            &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
            &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
            &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
        ]);

        (orig, oh_enc)
    }

    #[test]
    fn adjust_idxs() {
        let no_idxs: Vec<usize> = Vec::new();
        assert_eq!(find_new_idxs(0, &[], &[]), no_idxs);
        // [0,1,2] -> [0, 1, 1, 1, 2]
        assert_eq!(find_new_idxs(3, &[3], &[1]), vec![0, 1, 4]);
    }

    #[test]
    fn hash_encode_f64_series() {
        let series = vec![3.0, 1.0, 2.0, 1.0];
        let hashable_series: Vec<CategoricalFloat> =
            series.iter().map(|v| v.to_category()).collect();
        let enc = CategoryMapper::from_positional_category_vec(hashable_series);

        // The third position of the one-hot vector maps back to the third category
        let inv = enc.invert_one_hot(vec![0.0, 0.0, 1.0]);
        let orig_val: f64 = inv.unwrap().into();
        assert_eq!(orig_val, 2.0);
    }

    #[test]
    fn test_fit() {
        let (x, _) = build_fake_matrix();
        let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        assert_eq!(oh_enc.category_mappers.len(), 2);

        let mut num_cat: Vec<usize> = Vec::new();
        for mapper in &oh_enc.category_mappers {
            num_cat.push(mapper.num_categories());
        }
        assert_eq!(num_cat, vec![2, 4]);
    }

    #[test]
    fn matrix_transform_test() {
        let (x, expected_x) = build_fake_matrix();
        assert_encodes_to(x, expected_x, &[1, 3]);

        let (x, expected_x) = build_cat_first_and_last();
        assert_encodes_to(x, expected_x, &[0, 2]);
    }

    #[test]
    fn fail_on_bad_category() {
        let m = DenseMatrix::from_2d_array(&[
            &[1.0, 1.5, 3.0],
            &[2.0, 1.5, 4.0],
            &[1.0, 1.5, 5.0],
            &[2.0, 1.5, 6.0],
        ]);

        // Column 1 holds 1.5, which fitting is expected to reject
        let params = OneHotEncoderParams::from_cat_idx(&[1]);
        assert!(OneHotEncoder::fit(&m, params).is_err());
    }
}
0 commit comments