Skip to content

Commit

Permalink
Added follk handling for k-means
Browse files Browse the repository at this point in the history
  • Loading branch information
denysgerasymuk799 committed May 11, 2024
1 parent 3d1e40f commit 42dfb0d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cluster/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ python ./scripts/evaluate_models.py \

python ./scripts/impute_nulls_with_predictor.py \
--dataset folk \
--null_imputers \[\"miss_forest\"\] \
--null_imputers \[\"k_means_clustering\"\] \
--run_nums \[1\] \
--tune_imputers true \
--save_imputed_datasets true \
Expand Down
17 changes: 9 additions & 8 deletions source/null_imputers/imputation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from source.null_imputers.missforest_imputer import MissForestImputer
from source.null_imputers.kmeans_imputer import KMeansImputer
from source.utils.pipeline_utils import encode_dataset_for_missforest, decode_dataset_for_missforest
from source.utils.dataframe_utils import get_object_columns_indexes, get_numerical_columns_indexes
from source.utils.dataframe_utils import get_numerical_columns_indexes


def impute_with_deletion(X_train_with_nulls: pd.DataFrame, X_tests_with_nulls_lst: list,
Expand Down Expand Up @@ -127,18 +127,19 @@ def impute_with_kmeans(X_train_with_nulls: pd.DataFrame, X_tests_with_nulls_lst:
hyperparams: dict, **kwargs):
seed = kwargs['experiment_seed']
dataset_name = kwargs['dataset_name']

X_train_encoded, cat_encoders, categorical_columns_idxs = \
encode_dataset_for_missforest(X_train_with_nulls,
dataset_name=dataset_name,
categorical_columns_with_nulls=categorical_columns_with_nulls)

numerical_columns_idxs = get_numerical_columns_indexes(X_train_encoded)
# Set an appropriate kmeans_imputer_mode type
numerical_columns_idxs = get_numerical_columns_indexes(X_train_with_nulls)
if len(numerical_columns_idxs) == len(numeric_columns_with_nulls):
kmeans_imputer_mode = "kmodes"
else:
kmeans_imputer_mode = "kprototypes"


X_train_encoded, cat_encoders, categorical_columns_idxs = \
encode_dataset_for_missforest(X_train_with_nulls,
dataset_name=dataset_name,
categorical_columns_with_nulls=categorical_columns_with_nulls)

# Impute numerical columns
kmeans_imputer = KMeansImputer(seed=seed, imputer_mode=kmeans_imputer_mode, hyperparameters=hyperparams)

Expand Down

0 comments on commit 42dfb0d

Please sign in to comment.