Skip to content

Commit

Permalink
Added folk handling for k-means
Browse files Browse the repository at this point in the history
  • Loading branch information
denysgerasymuk799 committed May 11, 2024
1 parent 620fd7d commit 3d1e40f
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions source/null_imputers/imputation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,14 @@ def impute_with_kmeans(X_train_with_nulls: pd.DataFrame, X_tests_with_nulls_lst:
numeric_columns_with_nulls: list, categorical_columns_with_nulls: list,
hyperparams: dict, **kwargs):
seed = kwargs['experiment_seed']
dataset_name = kwargs['dataset_name']

X_train_encoded, cat_encoders, categorical_columns_idxs = encode_dataset_for_missforest(X_train_with_nulls)
X_train_encoded, cat_encoders, categorical_columns_idxs = \
encode_dataset_for_missforest(X_train_with_nulls,
dataset_name=dataset_name,
categorical_columns_with_nulls=categorical_columns_with_nulls)

numerical_columns_idxs = get_numerical_columns_indexes(X_train_encoded)

if len(numerical_columns_idxs) == len(numeric_columns_with_nulls):
kmeans_imputer_mode = "kmodes"
else:
Expand All @@ -140,24 +144,21 @@ def impute_with_kmeans(X_train_with_nulls: pd.DataFrame, X_tests_with_nulls_lst:

X_train_repaired_values = kmeans_imputer.fit_transform(X_train_encoded.values.astype(float), cat_vars=categorical_columns_idxs)
X_train_repaired = pd.DataFrame(X_train_repaired_values, columns=X_train_encoded.columns, index=X_train_encoded.index)
X_train_imputed = decode_dataset_for_missforest(X_train_repaired, cat_encoders)

# Set the same columns types as in the original dataset
# X_train_imputed[categorical_columns_with_nulls] = X_train_imputed[categorical_columns_with_nulls].astype(int).astype('str')
X_train_imputed = decode_dataset_for_missforest(X_train_repaired, cat_encoders, dataset_name=dataset_name)

X_tests_imputed_lst = []
for i in range(len(X_tests_with_nulls_lst)):
X_test_with_nulls = X_tests_with_nulls_lst[i]

X_test_encoded, _, _ = encode_dataset_for_missforest(X_test_with_nulls, cat_encoders=cat_encoders)

X_test_encoded, _, _ = encode_dataset_for_missforest(X_test_with_nulls,
cat_encoders=cat_encoders,
dataset_name=dataset_name,
categorical_columns_with_nulls=categorical_columns_with_nulls)
X_test_repaired_values = kmeans_imputer.transform(X_test_encoded.values.astype(float))
X_test_repaired = pd.DataFrame(X_test_repaired_values, columns=X_test_encoded.columns, index=X_test_encoded.index)
X_test_imputed = decode_dataset_for_missforest(X_test_repaired, cat_encoders)
X_test_imputed = decode_dataset_for_missforest(X_test_repaired, cat_encoders, dataset_name=dataset_name)

X_tests_imputed_lst.append(X_test_imputed)

print(f"X_test_imputed_list length: {len(X_tests_imputed_lst)}")

null_imp_params_dct = {col: kmeans_imputer.get_predictors_params() for col in X_train_with_nulls.columns}
null_imp_params_dct = {col: kmeans_imputer.get_predictors_params() for col in X_train_with_nulls.columns}
return X_train_imputed, X_tests_imputed_lst, null_imp_params_dct

0 comments on commit 3d1e40f

Please sign in to comment.