@@ -162,9 +162,13 @@ def _fit(
             # with nan values.
             # Columns that are completely made of NaN values are provided to the pipeline
             # so that later stages decide how to handle them
+
+            # Clear whatever null column markers we had previously
+            self.null_columns.clear()
             if np.any(pd.isnull(X)):
                 for column in X.columns:
                     if X[column].isna().all():
+                        self.null_columns.add(column)
                         X[column] = pd.to_numeric(X[column])
                         # Also note this change in self.dtypes
                         if len(self.dtypes) != 0:
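The fit-side bookkeeping above can be exercised in isolation. A minimal sketch (not part of the patch), assuming only numpy and pandas; the local null_columns set stands in for the validator's attribute of the same name:

import numpy as np
import pandas as pd

# A column made entirely of NaN values, stored with object dtype as it would
# arrive from raw tabular data.
X = pd.DataFrame({"a": [1, 2, 3], "b": [np.nan] * 3}, dtype=object)
null_columns = set()  # stand-in for self.null_columns

if np.any(pd.isnull(X)):
    for column in X.columns:
        if X[column].isna().all():
            null_columns.add(column)              # remember it for transform()
            X[column] = pd.to_numeric(X[column])  # treat it as numerical

print(null_columns)   # {'b'}
print(X["b"].dtype)   # float64, so it is not passed to the categorical encoder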
@@ -244,30 +248,38 @@ def transform(
         if isinstance(X, np.ndarray):
             X = self.numpy_array_to_pandas(X)

-        if ispandas(X) and not issparse(X):
-            if np.any(pd.isnull(X)):
-                for column in X.columns:
-                    if X[column].isna().all():
-                        X[column] = pd.to_numeric(X[column])
+        if hasattr(X, "iloc") and not issparse(X):
+            X = cast(pd.DataFrame, X)
+            # If we had null columns in our fit call and we made them numeric, then:
+            # - If the columns are null even in transform, apply the same procedure.
+            # - Otherwise, substitute the values with np.NaN and then make the columns numeric.
+            # A column that is null here but was not null in fit does not matter.
+            for column in self.null_columns:
+                # The column is not null now; make it null, since it was null in fit.
+                if not X[column].isna().all():
+                    X[column] = np.NaN
+                X[column] = pd.to_numeric(X[column])
+
+            # In the test set, columns with only null values will probably have a
+            # numeric type. If these columns did not contain only null values in
+            # the train set, they should be converted back to the type that they
+            # had during fitting.
+            for column in X.columns:
+                if X[column].isna().all():
+                    X[column] = X[column].astype(self.dtypes[list(X.columns).index(column)])

             # Also remove the object dtype for new data
             if not X.select_dtypes(include='object').empty:
                 X = self.infer_objects(X)

         # Check the data here so we catch problems on new test data
         self._check_data(X)
+        # We also need to fillna on the transformed data
+        # in case test data is provided
+        X = self.impute_nan_in_categories(X)

-        # Pandas related transformations
-        if ispandas(X) and self.column_transformer is not None:
-            if np.any(pd.isnull(X)):
-                # After above check it means that if there is a NaN
-                # the whole column must be NaN
-                # Make sure it is numerical and let the pipeline handle it
-                for column in X.columns:
-                    if X[column].isna().all():
-                        X[column] = pd.to_numeric(X[column])
-
-            X = self.column_transformer.transform(X)
+        if self.encoder is not None:
+            X = self.encoder.transform(X)

         # Sparse related transformations
         # Not all sparse format support index sorting
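The two transform-time rules above (re-nulling columns that were all-NaN during fit, and restoring the fit dtype for columns that are all-NaN only at test time) can be illustrated with a small, self-contained sketch. Here null_columns and train_dtypes are hypothetical stand-ins for the validator's fitted state, and the nullable Int64 dtype assumes pandas >= 1.0:

import numpy as np
import pandas as pd

null_columns = {"b"}                           # columns that were all-NaN during fit
train_dtypes = {"a": "Int64", "b": "float64"}  # dtypes recorded during fit

X_test = pd.DataFrame({"a": [np.nan, np.nan], "b": ["x", "y"]})

# Rule 1: a column that was all-NaN in fit is forced back to an all-NaN numeric column.
for column in null_columns:
    if not X_test[column].isna().all():
        X_test[column] = np.nan
    X_test[column] = pd.to_numeric(X_test[column])

# Rule 2: a column that is all-NaN only at test time is cast back to its fit dtype.
for column in X_test.columns:
    if X_test[column].isna().all():
        X_test[column] = X_test[column].astype(train_dtypes[column])

print(X_test.dtypes)  # a: Int64, b: float64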
@@ -557,7 +569,7 @@ def numpy_array_to_pandas(
         Returns:
             pd.DataFrame
         """
-        return pd.DataFrame(X).infer_objects().convert_dtypes()
+        return pd.DataFrame(X).convert_dtypes()

     def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
         """
@@ -575,18 +587,13 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
         if hasattr(self, 'object_dtype_mapping'):
             # Mypy does not process the has attr. This dict is defined below
             for key, dtype in self.object_dtype_mapping.items():  # type: ignore[has-type]
-                if 'int' in dtype.name:
-                    # In the case train data was interpreted as int
-                    # and test data was interpreted as float, because of 0.0
-                    # for example, honor training data
-                    X[key] = X[key].applymap(np.int64)
-                else:
-                    try:
-                        X[key] = X[key].astype(dtype.name)
-                    except Exception as e:
-                        # Try inference if possible
-                        self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
-                        pass
+                # honor the training data types
+                try:
+                    X[key] = X[key].astype(dtype.name)
+                except Exception as e:
+                    # Try inference if possible
+                    self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
+                    pass
         else:
             X = X.infer_objects()
         for column in X.columns:
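Why a plain astype now suffices in infer_objects: if training mapped an object column to an integer dtype but the test split is parsed as float (0.0 instead of 0, for instance), casting with the stored dtype restores the training type without the special applymap branch. A sketch under those assumptions; object_dtype_mapping here is a local stand-in for the attribute used above:

import numpy as np
import pandas as pd

object_dtype_mapping = {"price": np.dtype("int64")}  # recorded during fit

X_test = pd.DataFrame({"price": [0.0, 1.0, 2.0]})    # inferred as float64 at test time
for key, dtype in object_dtype_mapping.items():
    try:
        X_test[key] = X_test[key].astype(dtype.name)
    except Exception:
        # If the cast is impossible (e.g. NaN into int64), keep the inferred dtype.
        pass

print(X_test["price"].dtype)  # int64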