Skip to content

Commit

Permalink
Save TDM before implementing its new version
Browse files: browse the repository at this point in the history.
  • Loading branch information
denysgerasymuk799 committed Dec 26, 2024
1 parent 024aabb commit 7c12824
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 0 additions & 4 deletions source/null_imputers/imputation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,6 @@ def impute_with_tdm(X_train_with_nulls: pd.DataFrame, X_tests_with_nulls_lst: li
numeric_columns_with_nulls: list, categorical_columns_with_nulls: list,
hyperparams: dict, **kwargs):
dataset_name = kwargs['dataset_name']
seed = kwargs['experiment_seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

X_train_encoded, cat_encoders, _ = encode_dataset_for_missforest(df=X_train_with_nulls,
dataset_name=dataset_name,
Expand Down
9 changes: 7 additions & 2 deletions source/null_imputers/tdm_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def fit(self, X_train, verbose=True, report_interval=500):
X_train = X_train.clone()
n, d = X_train.shape

batch_size = self.batchsize
if batch_size > n // 2:
e = int(np.log2(n // 2))
batch_size = 2 ** e

mask = torch.isnan(X_train).double()
torch.autograd.set_detect_anomaly(True)

Expand All @@ -60,8 +65,8 @@ def fit(self, X_train, verbose=True, report_interval=500):
proj_loss = 0

for _ in range(self.n_pairs):
idx1 = np.random.choice(n, self.batchsize, replace=False)
idx2 = np.random.choice(n, self.batchsize, replace=False)
idx1 = np.random.choice(n, batch_size, replace=False)
idx2 = np.random.choice(n, batch_size, replace=False)

X1 = X_filled[idx1]
X2 = X_filled[idx2]
Expand Down

0 comments on commit 7c12824

Please sign in to comment.