@@ -150,6 +150,12 @@ def _fit(
150
150
all_nan_columns = X .columns [X .isna ().all ()]
151
151
for col in all_nan_columns :
152
152
X [col ] = pd .to_numeric (X [col ])
153
+
154
+ # Handle objects if possible
155
+ exist_object_columns = has_object_columns (X .dtypes .values )
156
+ if exist_object_columns :
157
+ X = self .infer_objects (X )
158
+
153
159
self .dtypes = [dt .name for dt in X .dtypes ] # Also note this change in self.dtypes
154
160
self .all_nan_columns = set (all_nan_columns )
155
161
@@ -260,20 +266,22 @@ def transform(
260
266
261
267
if hasattr (X , "iloc" ) and not scipy .sparse .issparse (X ):
262
268
X = cast (Type [pd .DataFrame ], X )
263
- if self .all_nan_columns is not None :
264
- for column in X .columns :
265
- if column in self .all_nan_columns :
266
- if not X [column ].isna ().all ():
267
- X [column ] = np .nan
268
- X [column ] = pd .to_numeric (X [column ])
269
- if len (self .categorical_columns ) > 0 :
270
- if self .column_transformer is None :
271
- raise AttributeError ("Expect column transformer to be built"
272
- "if there are categorical columns" )
273
- categorical_columns = self .column_transformer .transformers_ [0 ][- 1 ]
274
- for column in categorical_columns :
275
- if X [column ].isna ().all ():
276
- X [column ] = X [column ].astype ('object' )
269
+
270
+ if self .all_nan_columns is None :
271
+ raise ValueError ('_fit must be called before calling transform' )
272
+
273
+ for col in list (self .all_nan_columns ):
274
+ X [col ] = np .nan
275
+ X [col ] = pd .to_numeric (X [col ])
276
+
277
+ if len (self .categorical_columns ) > 0 :
278
+ if self .column_transformer is None :
279
+ raise AttributeError ("Expect column transformer to be built"
280
+ "if there are categorical columns" )
281
+ categorical_columns = self .column_transformer .transformers_ [0 ][- 1 ]
282
+ for column in categorical_columns :
283
+ if X [column ].isna ().all ():
284
+ X [column ] = X [column ].astype ('object' )
277
285
278
286
# Check the data here so we catch problems on new test data
279
287
self ._check_data (X )
@@ -366,10 +374,10 @@ def _check_data(
366
374
self .column_order = column_order
367
375
368
376
dtypes = [dtype .name for dtype in X .dtypes ]
369
-
377
+ diff_cols = X . columns [[ s_dtype != dtype for s_dtype , dtype in zip ( self . dtypes , dtypes )]]
370
378
if len (self .dtypes ) == 0 :
371
379
self .dtypes = dtypes
372
- elif self .dtypes != dtypes :
380
+ elif not self ._is_datasets_consistent ( diff_cols , X ) :
373
381
raise ValueError ("The dtype of the features must not be changed after fit(), but"
374
382
" the dtypes of some columns are different between training ({}) and"
375
383
" test ({}) datasets." .format (self .dtypes , dtypes ))
@@ -517,11 +525,17 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
517
525
self .logger .warning (f'Casting the column { key } to { dtype } caused the exception { e } ' )
518
526
pass
519
527
else :
520
- # Calling for the first time to infer the categories
521
- X = X .infer_objects ()
522
- for column , data_type in zip (X .columns , X .dtypes ):
523
- if not is_numeric_dtype (data_type ):
524
- X [column ] = X [column ].astype ('category' )
528
+ if len (self .dtypes ) != 0 :
529
+ # when train data has no object dtype, but test does
530
+ # we prioritise the datatype given in training data
531
+ for column , data_type in zip (X .columns , self .dtypes ):
532
+ X [column ] = X [column ].astype (data_type )
533
+ else :
534
+ # Calling for the first time to infer the categories
535
+ X = X .infer_objects ()
536
+ for column , data_type in zip (X .columns , X .dtypes ):
537
+ if not is_numeric_dtype (data_type ):
538
+ X [column ] = X [column ].astype ('category' )
525
539
526
540
# only numerical attributes and categories
527
541
self .object_dtype_mapping = {column : data_type for column , data_type in zip (X .columns , X .dtypes )}
0 commit comments