Skip to content

Commit 828df4e

Browse files
committed
Use CategoryMapper to transform an iterator. No more passing iterator to SeriesEncoders
1 parent 374dfec commit 828df4e

File tree

1 file changed

+33
-34
lines changed

1 file changed

+33
-34
lines changed

src/preprocessing/categorical_encoder.rs

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
//! # One-hot Encoding For [RealNumber](../../math/num/trait.RealNumber.html) Matrices
22
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by replacing all categorical variables with their one-hot equivalents
33
//!
4-
//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [SeriesOneHotEncoder](../series_encoder/struct.SeriesOneHotEncoder.html)
4+
//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [CategoryMapper](../series_encoder/struct.CategoryMapper.html)
55
//!
66
//! ### Usage Example
77
//! ```
88
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
9-
//! use smartcore::preprocessing::categorical_encoder::{OneHotEnc, OneHotEncoderParams};
9+
//! use smartcore::preprocessing::categorical_encoder::{OneHotEncoder, OneHotEncoderParams};
1010
//! let data = DenseMatrix::from_2d_array(&[
1111
//! &[1.5, 1.0, 1.5, 3.0],
1212
//! &[1.5, 2.0, 1.5, 4.0],
@@ -15,7 +15,7 @@
1515
//! ]);
1616
//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
1717
//! // Infer number of categories from data and return a reusable encoder
18-
//! let encoder = OneHotEnc::fit(&data, encoder_params).unwrap();
18+
//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap();
1919
//! // Transform categorical columns to one-hot encoding (the fitted encoder can be reused on similar data)
2020
//! let oh_data = encoder.transform(&data).unwrap();
2121
//! // Produces the following:
@@ -30,7 +30,7 @@ use crate::error::Failed;
3030
use crate::linalg::Matrix;
3131

3232
use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable};
33-
use crate::preprocessing::series_encoder::{SeriesOneHotEncoder, SeriesEncoder};
33+
use crate::preprocessing::series_encoder::CategoryMapper;
3434

3535
/// OneHotEncoder Parameters
3636
#[derive(Debug, Clone)]
@@ -97,17 +97,18 @@ fn validate_col_is_categorical<T: Categorizable>(data: &[T]) -> bool {
9797

9898
/// Encode categorical variables of a data matrix to one-hot
9999
#[derive(Debug, Clone)]
100-
pub struct OneHotEncoder<E> {
101-
series_encoders: Vec<E>,
100+
pub struct OneHotEncoder {
101+
category_mappers: Vec<CategoryMapper<CategoricalFloat>>,
102102
col_idx_categorical: Vec<usize>,
103103
}
104104

105-
impl<E: SeriesEncoder<CategoricalFloat>> OneHotEncoder<E> {
105+
impl OneHotEncoder {
106106
/// Create an encoder instance with categories inferred from the data matrix
107-
pub fn fit<T: Categorizable, M: Matrix<T>>(
108-
data: &M,
109-
params: OneHotEncoderParams,
110-
) -> Result<OneHotEncoder<E>, Failed> {
107+
pub fn fit<T, M>(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed>
108+
where
109+
T: Categorizable,
110+
M: Matrix<T>,
111+
{
111112
match (params.col_idx_categorical, params.infer_categorical) {
112113
(None, false) => Err(Failed::fit(
113114
"Must pass categorical series ids or infer flag",
@@ -126,8 +127,7 @@ impl<E: SeriesEncoder<CategoricalFloat>> OneHotEncoder<E> {
126127
// col buffer to avoid allocations
127128
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
128129

129-
let mut res: Vec<E> =
130-
Vec::with_capacity(idxs.len());
130+
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
131131

132132
for &idx in &idxs {
133133
data.copy_col_as_vec(idx, &mut col_buf);
@@ -139,11 +139,11 @@ impl<E: SeriesEncoder<CategoricalFloat>> OneHotEncoder<E> {
139139
return Err(Failed::fit(&msg[..]));
140140
}
141141
let hashable_col = col_buf.iter().map(|v| v.to_category());
142-
res.push(E::fit_to_iter(hashable_col));
142+
res.push(CategoryMapper::fit_to_iter(hashable_col));
143143
}
144144

145145
Ok(Self {
146-
series_encoders: res, //Self::build_series_encoders::<T, M>(data, &idxs[..]),
146+
category_mappers: res,
147147
col_idx_categorical: idxs,
148148
})
149149
}
@@ -155,10 +155,14 @@ impl<E: SeriesEncoder<CategoricalFloat>> OneHotEncoder<E> {
155155
}
156156

157157
/// Transform categorical variables to one-hot encoded and return a new matrix
158-
pub fn transform<T: Categorizable, M: Matrix<T>>(&self, x: &M) -> Result<M, Failed> {
158+
pub fn transform<T, M>(&self, x: &M) -> Result<M, Failed>
159+
where
160+
T: Categorizable,
161+
M: Matrix<T>,
162+
{
159163
let (nrows, p) = x.shape();
160164
let additional_params: Vec<usize> = self
161-
.series_encoders
165+
.category_mappers
162166
.iter()
163167
.map(|enc| enc.num_categories())
164168
.collect();
@@ -172,10 +176,10 @@ impl<E: SeriesEncoder<CategoricalFloat>> OneHotEncoder<E> {
172176
for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
173177
let cidx = new_col_idx[old_cidx];
174178
let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category());
175-
let sencoder = &self.series_encoders[pidx];
176-
let oh_series: Vec<Option<Vec<T>>> = sencoder.transform_iter(col_iter);
179+
let sencoder = &self.category_mappers[pidx];
180+
let oh_series = col_iter.map(|c| sencoder.get_one_hot::<T, Vec<T>>(&c));
177181

178-
for (row, oh_vec) in oh_series.iter().enumerate() {
182+
for (row, oh_vec) in oh_series.enumerate() {
179183
match oh_vec {
180184
None => {
181185
// Since we support T types, a bad value in a series causes it to be invalid
@@ -215,16 +219,11 @@ impl<E: SeriesEncoder<CategoricalFloat>> OneHotEncoder<E> {
215219
}
216220
}
217221

218-
/// Convinince type for common use
219-
pub type OneHotEnc = OneHotEncoder<SeriesOneHotEncoder<CategoricalFloat>>;
220-
221-
222222
#[cfg(test)]
223223
mod tests {
224224
use super::*;
225225
use crate::linalg::naive::dense_matrix::DenseMatrix;
226-
use crate::preprocessing::series_encoder::SeriesOneHotEncoder;
227-
226+
use crate::preprocessing::series_encoder::CategoryMapper;
228227

229228
#[test]
230229
fn adjust_idxs() {
@@ -275,20 +274,20 @@ mod tests {
275274
let series = vec![3.0, 1.0, 2.0, 1.0];
276275
let hashable_series: Vec<CategoricalFloat> =
277276
series.iter().map(|v| v.to_category()).collect();
278-
let enc = SeriesOneHotEncoder::from_positional_category_vec(hashable_series);
279-
let inv = enc.invert_one(vec![0.0, 0.0, 1.0]);
277+
let enc = CategoryMapper::from_positional_category_vec(hashable_series);
278+
let inv = enc.invert_one_hot(vec![0.0, 0.0, 1.0]);
280279
let orig_val: f64 = inv.unwrap().into();
281280
assert_eq!(orig_val, 2.0);
282281
}
283282
#[test]
284283
fn test_fit() {
285284
let (x, _) = build_fake_matrix();
286285
let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
287-
let oh_enc = OneHotEnc::fit(&x, params).unwrap();
288-
assert_eq!(oh_enc.series_encoders.len(), 2);
286+
let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
287+
assert_eq!(oh_enc.category_mappers.len(), 2);
289288

290289
let num_cat: Vec<usize> = oh_enc
291-
.series_encoders
290+
.category_mappers
292291
.iter()
293292
.map(|a| a.num_categories())
294293
.collect();
@@ -299,13 +298,13 @@ mod tests {
299298
fn matrix_transform_test() {
300299
let (x, expected_x) = build_fake_matrix();
301300
let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
302-
let oh_enc = OneHotEnc::fit(&x, params).unwrap();
301+
let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
303302
let nm = oh_enc.transform(&x).unwrap();
304303
assert_eq!(nm, expected_x);
305304

306305
let (x, expected_x) = build_cat_first_and_last();
307306
let params = OneHotEncoderParams::from_cat_idx(&[0, 2]);
308-
let oh_enc = OneHotEnc::fit(&x, params).unwrap();
307+
let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
309308
let nm = oh_enc.transform(&x).unwrap();
310309
assert_eq!(nm, expected_x);
311310
}
@@ -320,7 +319,7 @@ mod tests {
320319
]);
321320

322321
let params = OneHotEncoderParams::from_cat_idx(&[1]);
323-
match OneHotEnc::fit(&m, params) {
322+
match OneHotEncoder::fit(&m, params) {
324323
Err(_) => {
325324
assert!(true);
326325
}

0 commit comments

Comments
 (0)