Merge pull request #262 from WenjieDu/make_csdi_val_process_original
Making CSDI val process same as the original
WenjieDu authored Dec 8, 2023
2 parents d504e6b + 1a41f34 commit 01ddf3b
Showing 3 changed files with 31 additions and 39 deletions.
44 changes: 22 additions & 22 deletions pypots/imputation/csdi/data.py
@@ -24,18 +24,6 @@ def __init__(
         rate: float = 0.1,
     ):
         super().__init__(data, return_labels, file_type)
-        self.time_points = (
-            None if "time_points" not in data.keys() else data["time_points"]
-        )
-        # _, self.time_points = self._check_input(self.X, time_points)
-        self.for_pattern_mask = (
-            None if "for_pattern_mask" not in data.keys() else data["for_pattern_mask"]
-        )
-        # _, self.for_pattern_mask = self._check_input(self.X, for_pattern_mask)
-        self.cut_length = (
-            None if "cut_length" not in data.keys() else data["cut_length"]
-        )
-        # _, self.cut_length = self._check_input(self.X, cut_length)
         self.rate = rate
 
     def _fetch_data_from_array(self, idx: int) -> Iterable:
@@ -71,19 +59,21 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
 
         observed_data = X_intact
         observed_mask = missing_mask + indicating_mask
-        gt_mask = missing_mask
         observed_tp = (
             torch.arange(0, self.n_steps, dtype=torch.float32)
-            if self.time_points is None
-            else self.time_points[idx].to(torch.float32)
+            if "time_points" not in self.data.keys()
+            else torch.from_numpy(self.data["time_points"][idx]).to(torch.float32)
         )
+        gt_mask = missing_mask
         for_pattern_mask = (
-            gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
+            gt_mask
+            if "for_pattern_mask" not in self.data.keys()
+            else torch.from_numpy(self.data["for_pattern_mask"][idx]).to(torch.float32)
         )
         cut_length = (
             torch.zeros(len(observed_data)).long()
-            if self.cut_length is None
-            else self.cut_length[idx]
+            if "cut_length" not in self.data.keys()
+            else torch.from_numpy(self.data["cut_length"][idx]).to(torch.float32)
         )
 
         sample = [
@@ -124,15 +114,25 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
 
         observed_data = X_intact
         observed_mask = missing_mask + indicating_mask
-        observed_tp = self.time_points[idx].to(torch.float32)
         gt_mask = indicating_mask
+        observed_tp = (
+            torch.arange(0, self.n_steps, dtype=torch.float32)
+            if "time_points" not in self.file_handle.keys()
+            else torch.from_numpy(self.file_handle["time_points"][idx]).to(
+                torch.float32
+            )
+        )
         for_pattern_mask = (
-            gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
+            gt_mask
+            if "for_pattern_mask" not in self.file_handle.keys()
+            else torch.from_numpy(self.file_handle["for_pattern_mask"][idx]).to(
+                torch.float32
+            )
         )
         cut_length = (
             torch.zeros(len(observed_data)).long()
-            if self.cut_length is None
-            else self.cut_length[idx]
+            if "cut_length" not in self.file_handle.keys()
+            else torch.from_numpy(self.file_handle["cut_length"][idx]).to(torch.float32)
         )
 
         sample = [
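Both _fetch_data_from_array and _fetch_data_from_file now resolve the optional entries ("time_points", "for_pattern_mask", "cut_length") at fetch time instead of caching them in __init__. Below is a minimal sketch of that lookup-with-fallback pattern, using a hypothetical resolve_optional_fields helper that is not part of PyPOTS and assuming each optional entry is a NumPy array indexed per sample:

import numpy as np
import torch


def resolve_optional_fields(data: dict, idx: int, n_steps: int, gt_mask: torch.Tensor):
    # Hypothetical helper mirroring the pattern in the diff above (not part of PyPOTS):
    # each optional key falls back to a default when absent from the data container.
    observed_tp = (
        torch.arange(0, n_steps, dtype=torch.float32)
        if "time_points" not in data.keys()
        else torch.from_numpy(np.asarray(data["time_points"][idx])).to(torch.float32)
    )
    for_pattern_mask = (
        gt_mask  # default: reuse the ground-truth mask
        if "for_pattern_mask" not in data.keys()
        else torch.from_numpy(np.asarray(data["for_pattern_mask"][idx])).to(torch.float32)
    )
    cut_length = (
        torch.zeros(n_steps).long()  # default when no cut_length is provided
        if "cut_length" not in data.keys()
        else torch.from_numpy(np.asarray(data["cut_length"][idx])).to(torch.float32)
    )
    return observed_tp, for_pattern_mask, cut_length

The same pattern works whether the container is an in-memory dict or an open HDF5 file handle, since both support .keys() and per-key indexing, which is why the array and file fetch paths in the diff end up looking alike.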
22 changes: 7 additions & 15 deletions pypots/imputation/csdi/model.py
@@ -32,7 +32,6 @@
 from ...optim.adam import Adam
 from ...optim.base import Optimizer
 from ...utils.logging import logger
-from ...utils.metrics import calc_mse
 
 
 class CSDI(BaseNNImputer):
@@ -245,23 +244,16 @@ def _train_model(
 
             if val_loader is not None:
                 self.model.eval()
-                imputation_collector = []
+                val_loss_collector = []
                 with torch.no_grad():
                     for idx, data in enumerate(val_loader):
                         inputs = self._assemble_input_for_validating(data)
-                        results = self.model.forward(inputs, training=False)
-                        imputed_data = results["imputed_data"].mean(axis=1)
-                        imputation_collector.append(imputed_data)
-
-                imputation_collector = torch.cat(imputation_collector)
-                imputation_collector = imputation_collector.cpu().detach().numpy()
-
-                mean_val_loss = calc_mse(
-                    imputation_collector,
-                    val_loader.dataset.data["X_intact"],
-                    val_loader.dataset.data["indicating_mask"],
-                    # the above val_loader.dataset.data is a dict containing the validation dataset
-                )
+                        results = self.model.forward(
+                            inputs, training=False, n_sampling_times=0
+                        )
+                        val_loss_collector.append(results["loss"].item())
+
+                mean_val_loss = np.asarray(val_loss_collector).mean()
 
                 # save validating loss logs into the tensorboard file for every epoch if in need
                 if self.summary_writer is not None:
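With this change the validation loop no longer imputes and scores against X_intact; it reuses the model's own loss with sampling disabled, which (per the commit message) mirrors the original CSDI implementation and removes the need for the calc_mse import dropped above. A minimal sketch of the new flow, assuming a forward that accepts n_sampling_times=0 and returns a dict containing "loss", with assemble_inputs standing in for self._assemble_input_for_validating:

import numpy as np
import torch


def mean_validation_loss(model, val_loader, assemble_inputs) -> float:
    # Sketch of the validation pass in the diff above: average the per-batch
    # loss instead of imputing and computing MSE against held-out values.
    model.eval()
    val_loss_collector = []
    with torch.no_grad():
        for data in val_loader:
            inputs = assemble_inputs(data)
            # n_sampling_times=0 skips reverse-diffusion sampling entirely,
            # so only the training-style loss is computed.
            results = model.forward(inputs, training=False, n_sampling_times=0)
            val_loss_collector.append(results["loss"].item())
    return float(np.asarray(val_loss_collector).mean())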
4 changes: 2 additions & 2 deletions pypots/imputation/csdi/modules/core.py
@@ -259,7 +259,7 @@ def forward(self, inputs, training=True, n_sampling_times=1):
         loss = loss_func(observed_data, cond_mask, observed_mask, side_info, training)
         results = {"loss": loss}
 
-        if not training:
+        if not training and n_sampling_times > 0:
             samples = self.impute(
                 observed_data, cond_mask, side_info, n_sampling_times
             )  # (bz,n_sampling,K,L)
@@ -269,6 +269,6 @@
 
             results["imputed_data"] = imputed_data.permute(
                 0, 1, 3, 2
-            )  # (bz,n_sampling,K,L)
+            )  # (bz,n_sampling,L,K)
 
         return results
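The final change only corrects the shape annotation: after permute(0, 1, 3, 2) the feature axis K and the time axis L are swapped. A quick sanity check with purely illustrative dimensions:

import torch

bz, n_sampling, K, L = 2, 3, 5, 7  # batch, sampling times, features, time steps (illustrative)
samples = torch.randn(bz, n_sampling, K, L)
imputed_data = samples.permute(0, 1, 3, 2)
assert imputed_data.shape == (bz, n_sampling, L, K)  # features and steps swapped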
