
Commit

TECH-yufu committed Dec 14, 2022
2 parents 8f9737d + fa78316 commit 60e6cca
Showing 6 changed files with 221 additions and 272 deletions.
8 changes: 6 additions & 2 deletions dataloaders.py
@@ -106,14 +106,18 @@ def dataset_info_restructure(dataset_name, data):
inv_var_dtype[x] = k
# finding number of values per variable
var_info = {}

offset = 0
for idx, variable_name in enumerate(list(data.columns)):
if inv_var_dtype[idx] == 'categorical':
new_columns = pd.get_dummies(data[variable_name])
new_columns_names = list(variable_name + '_' + new_columns.columns.astype('str'))
data[new_columns_names] = new_columns
for i, name in enumerate(new_columns_names):
data.insert(loc=idx+i+1+offset, column=name, value=new_columns.iloc[:,i])
# data[new_columns_names] = new_columns
num_unique = len(new_columns_names) # num unique values
# dropping original column
offset += num_unique
offset -= 1
data.drop(columns=variable_name, inplace=True)
var_info[idx] = {'name': variable_name, 'dtype': 'categorical', 'num_vals': num_unique}
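
A minimal sketch of the insertion logic above, assuming a hypothetical toy DataFrame (names are made up, not from this commit). The offset keeps each block of one-hot columns at the position of the column it replaces, since every processed categorical variable grows the frame by num_unique columns and shrinks it by one:

import pandas as pd

df = pd.DataFrame({'age': [23, 41, 35], 'color': ['red', 'blue', 'red']})
offset = 0
for idx, name in enumerate(list(df.columns)):
    if df[name].dtype == object:  # treat object columns as categorical
        dummies = pd.get_dummies(df[name])
        dummy_names = [f'{name}_{c}' for c in dummies.columns]
        for i, col in enumerate(dummy_names):
            # place each one-hot column right after the original position
            df.insert(loc=idx + i + 1 + offset, column=col, value=dummies.iloc[:, i])
        offset += len(dummy_names) - 1  # equivalent to `offset += num_unique; offset -= 1`
        df.drop(columns=name, inplace=True)
print(list(df.columns))  # ['age', 'color_blue', 'color_red']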

15 changes: 5 additions & 10 deletions main.py
@@ -1,3 +1,4 @@
import os
import argparse
from logger import Logger
import numpy as np
@@ -114,13 +115,7 @@
# Save and plot test_results
# loading best model
model = load_model(model_path=result_dir, model=model, device=device)
# test_loss = evaluation(model, test_loader, device)
# todo: extend outputs
# NLL, MSE, imputation_error = evaluate_to_table(model, test_loader, device)
get_test_results(model=model, result_path=result_dir, model_name=name, test_loader=test_loader, var_info=var_info, device=device)

# evaluate_to_table(test_loader, var_info, name=None, model_best=None, epoch=None, M=256, natural=False,
# device=None)



model.eval()
imputation_ratio=0.5 # todo extend to arguments?
results_df = get_test_results(model=model, test_loader=test_loader, var_info=var_info, device=device, imputation_ratio=imputation_ratio)
print(results_df)
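
The imputation_ratio passed above presumably controls the fraction of entries masked before reconstruction; the actual masking lives inside get_test_results. A minimal sketch of such a mask (hypothetical, for illustration only):

import torch

torch.manual_seed(0)
x = torch.randn(4, 6)
keep = (torch.rand_like(x) > 0.5).float()  # drop roughly 50% of entries
x_masked = x * keep  # zeroed entries play the role of missing values
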
211 changes: 61 additions & 150 deletions models.py
@@ -1,15 +1,10 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
import torch.nn as nn
import torch.nn.functional as F
from prob_dists import *

from pytorch_model_summary import summary

# importing distributions
import torch.distributions as dists

@@ -104,7 +99,12 @@ def decode(self, z):
# input are latent variables, 2*L (mean and variance)
# the output depends on expected output distribution, see below.
h_d = self.decoder(z) # note: 'decode' and 'decoder' to minimize confusion
prob_d = torch.zeros(h_d.shape)

if self.natural:
prob_d = h_d.clone()
else:
prob_d = torch.zeros(h_d.shape)

# hidden output of decoder
idx = 0
# decoder outputs the distribution parameters (e.g. mu, sigma, eta's)
@@ -128,19 +128,14 @@ def decode(self, z):
else:
# eta2 has to be negative: -inf < eta2 < 0
# extracting eta2
eta2 = h_d[:, idx+1:idx + num_vals]
# todo is this correct?
h_d[:, idx + 1:idx + num_vals] = -self.softplus(eta2)
prob_d[:, idx+1:idx + num_vals] = -self.softplus(h_d[:, idx+1:idx + num_vals])
idx += num_vals
else:
raise ValueError('Either `categorical` or `gaussian`')
# returning output
if not self.natural:
# returning probability distribution
return prob_d
else:
# returning the etas
return h_d

# returning probability distribution or returning the etas
return prob_d
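
For the Gaussian case, the natural parameters relate to the mean and variance by eta1 = mu / sigma^2 and eta2 = -1 / (2 sigma^2), which is why eta2 must stay negative. A quick numeric check of the inversion used in log_prob below (a sketch, not part of the commit):

import torch

mu, var = torch.tensor(2.0), torch.tensor(0.5)
eta1, eta2 = mu / var, -0.5 / var      # eta1 = 4, eta2 = -1
mu_back = -0.5 * eta1 / eta2           # recovers 2.0
log_var_back = torch.log(-0.5 / eta2)  # recovers log(0.5)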


def sample(self, z):
# TODO: cannot do flatten if z is batched
@@ -218,13 +213,12 @@ def log_prob(self, x, z):
else:
probs = prob_d[:, prob_d_idx:prob_d_idx + num_vals]

log_p[:, var] = log_categorical(x[:, x_idx:x_idx + 1], probs, num_classes=num_vals, reduction='sum',
dim=-1).sum(-1)
log_p[:, var] = log_categorical(x[:, prob_d_idx:prob_d_idx + num_vals], probs, reduction='sum', dim=1)#.sum(-1)

prob_d_idx += num_vals

elif self.var_info[var]['dtype'] == 'numerical': # Gaussian
num_vals = self.var_info[var]['num_vals']
num_vals = self.var_info[var]['num_vals'] # always 2

if self.natural:
# -softplus has already been applied to eta2 (in decode)
@@ -233,16 +227,18 @@
mu, log_var = -0.5 * eta1 / eta2, torch.log(-0.5 / eta2)
else:
mu, log_var = torch.chunk(prob_d[:, prob_d_idx:prob_d_idx + num_vals], 2, dim=1)
log_p[:, var] = log_normal(x[:, x_idx:x_idx + 1], mu, log_var, reduction='sum', dim=-1).sum(-1)
prob_d_idx += num_vals
log_p[:, var] = log_normal(x[:, prob_d_idx:prob_d_idx + num_vals], mu, log_var, reduction='sum', dim=-1) #.sum(-1)
prob_d_idx += 1

elif self.var_info[var]['dtype'] == 'bernoulli':
log_p = log_bernoulli(x, prob_d, reduction='sum', dim=-1)

else:
raise ValueError('Either `gaussian`, `categorical`, or `bernoulli`')
# summing log_p for each variable (i.e. we have one log prob per batch)
# - e.g. all categories in categorical, i.e. dimension is batch
return torch.sum(log_p, axis=1).to(self.device)

return log_p.sum(axis=1).to(self.device) # summing all log_probs

def forward(self, z, x=None, type='log_prob'):
assert type in ['decoder', 'log_prob'], 'Type could be either decode or log_prob'
@@ -296,142 +292,57 @@ def __init__(self, total_num_vals, L, var_info, D, M, natural, device):

self.var_info = var_info # contains type of likelihood for variables

self.D = D

self.device = device

def forward(self, x, reduction='sum'):
def forward(self, x, loss=True, reconstruct=False, nll=False, reduction='sum'):

#if reduction == 'sum':
# metric = SumMetric()
#elif reduction == 'avg':
# metric = MeanMetric()
#else:
# raise NotImplementedError('unknown reduction')


# Encode
mu_e, log_var_e = self.encoder.encode(x)
# sample in latent space
z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)
z = z.to(self.device)
# Sample/predict
output = self.decoder.sample(z)

# x_params = [head(y_shared, s_samples) for head in self.heads]

# Loss
# ELBO
# reconstruction error
RE = self.decoder.log_prob(x, z) # z is decoded back
# Kullback–Leibler divergence, regularizer
KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu_e=mu_e, log_var_e=log_var_e, z=z)).sum(-1)
# loss
nll = (RE / self.total_num_vals).mean().detach()
rmse = self.RMSE(x, output)
# returning the output, the loss and the metrics
if reduction == 'sum':
return {'output': output}, {'loss': -(RE + KL).sum()}, {'NLL': nll, 'RMSE': rmse}
# averaging
elif reduction == 'avg':
return {'output': output}, {'loss': -(RE + KL).mean()}, {'NLL': nll, 'RMSE': rmse}

def RMSE(self, x, x_recon):
var_idx = 0
MSE = 0
RMSE = 0
D = len(self.var_info.keys()) # Num variables
obs_in_batch = x.shape[0] # Num observations in batch
for var in self.var_info.keys():
num_vals = self.var_info[var]['num_vals']

# Getting length of slice
if self.var_info[var]['dtype'] == 'numerical':
idx_slice = 1
else: # categorical
idx_slice = num_vals

# Imputation targets and predictions - for variable
var_targets = x[:, var_idx:var_idx + idx_slice]
var_preds = x_recon[:, var_idx:var_idx + idx_slice]

# MSE per variable
MSE_var = torch.sum((var_targets - var_preds) ** 2) / obs_in_batch

# Summing variable MSEs - (outer-most sum of formula)
MSE += MSE_var

# Updating current variable index
if self.var_info[var]['dtype'] == 'numerical':
var_idx += 1
else: # categorical
var_idx += num_vals

# Taking square-root (RMSE), and averaging over features. (As seen in formula)
RMSE += torch.sqrt(MSE) / D

return RMSE
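
A toy check of the slicing above, with one numerical column and one two-category one-hot variable (values are made up):

import torch

x     = torch.tensor([[1.0, 1.0, 0.0],
                      [3.0, 0.0, 1.0]])
x_hat = torch.tensor([[2.0, 0.5, 0.5],
                      [3.0, 0.5, 0.5]])
n = x.shape[0]
mse_num = torch.sum((x[:, :1] - x_hat[:, :1]) ** 2) / n  # (1 + 0) / 2 = 0.5
mse_cat = torch.sum((x[:, 1:] - x_hat[:, 1:]) ** 2) / n  # 4 * 0.25 / 2 = 0.5
rmse = (torch.sqrt(mse_num) + torch.sqrt(mse_cat)) / 2   # (sqrt(0.5) + sqrt(0.5)) / 2 = 0.7071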

""" def RMSE(self, ):
def RMSE(test_loader, var_info, model):
model.eval()
RMSE = 0 # Initializing RMSE score
# Num variables
D = len(var_info.keys())
# Getting the reconstructed test_batch by sending the imputed test batch through VAE
x_recon = model.forward(x)[0]['output'].detach().numpy()
var_idx = 0
MSE = 0
for var in var_info.keys():
num_vals = var_info[var]['num_vals']
# Getting length of slice
if var_info[var]['dtype'] == 'numerical':
idx_slice = 1
else: # categorical
idx_slice = num_vals
# MSE per variable - for all unobserved slots (inner-most sum of formula)
MSE_var = torch.mean((x - x_recon) ** 2)
# Summing variable MSEs - (outer-most sum of formula)
MSE += MSE_var
# Updating current variable index
if var_info[var]['dtype'] == 'numerical':
var_idx += 1
else: # categorical
var_idx += num_vals
# Taking square-root (RMSE), and averaging over features. (As seen in formula)
RMSE += torch.sqrt(MSE) / D
# Getting average RMSE across batches
total_batches = indx_batch + 1
return RMSE / total_batches
def nLLloss(self, x, y_true):
mu_e, log_var_e = self.encoder.encode(x)
z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)
z = z.to(self.device)
RE = self.decoder.log_prob(x, z)
y_pred = RE
loss = nn.NLLLoss()
nllloss = loss(y_pred, y_true)
return nllloss
def mseloss(self, x, y_true):
mu_e, log_var_e = self.encoder.encode(x)
z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)
z = z.to(self.device)
RE = self.decoder.log_prob(x, z)
y_pred = RE
loss = nn.MSELoss()
nllloss = loss(y_pred, y_true)
return nllloss
def imputation_error(self, x, y_true):
return self.nLLloss(x, y_true)
def sample(self, batch_size=64):
z = self.prior.sample(batch_size=batch_size)
return self.decoder.sample(z)
"""
# reconstruct
RECONSTRUCTION = None
if reconstruct:
# Sample/predict
RECONSTRUCTION = self.decoder.sample(z)
# updated
#reconstruction_dict = {'reconstruction': output}

LOSS = None
if loss:
# ELBO
# reconstruction error
RE = self.decoder.log_prob(x, z) # z is decoded back
# Kullback–Leibler divergence, regularizer
KL = (self.prior.log_prob(z) - self.encoder.log_prob(mu_e=mu_e, log_var_e=log_var_e, z=z)).sum(-1)
# summing the loss for this batch
LOSS = -(RE + KL).sum()
#loss_dict = {'loss': -(RE + KL).sum()}

NLL = None
if nll:
assert (nll and loss) == True, 'loss also has to be true in input for forward call'
# loss
# first find NLL averaged over variables in data
# then average over the batch
# updated
NLL = (RE / self.D).mean().detach()
#nll_dict = {'NLL': (RE / self.D).mean().detach()} #, 'RMSE': rmse}


return RECONSTRUCTION, LOSS, NLL
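
The quantity called KL above is a single-sample Monte Carlo estimate, log p(z) - log q(z|x), rather than the closed-form divergence. A self-contained sketch of that estimator for diagonal Gaussians (the log(2*pi) constants cancel in the difference and are dropped):

import torch

def mc_kl_term(mu_e, log_var_e):
    # reparameterized sample from q(z|x)
    z = mu_e + torch.exp(0.5 * log_var_e) * torch.randn_like(mu_e)
    log_q = (-0.5 * (log_var_e + (z - mu_e) ** 2 / log_var_e.exp())).sum(-1)
    log_p = (-0.5 * z ** 2).sum(-1)  # standard normal prior
    return log_p - log_q  # add the decoder log-likelihood RE to get the ELBO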


class Baseline():
9 changes: 6 additions & 3 deletions prob_dists.py
@@ -8,9 +8,12 @@
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.e-5

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
x_one_hot = F.one_hot(x.long(), num_classes=num_classes)
log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
def log_categorical(x, p, reduction='sum', dim=None):
#x_one_hot = F.one_hot(x.long(), num_classes=num_classes)
#log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
# using clamp to avoid log(0) by setting min and max values as EPS and 1-EPS
log_p = x * torch.log(torch.clamp(p, EPS, 1. - EPS))

if reduction == 'avg':
return torch.mean(log_p, dim)
elif reduction == 'sum':
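A quick sanity check of the new signature, which now expects x to arrive one-hot encoded (the function is restated here only to keep the sketch self-contained):

import torch

EPS = 1.e-5

def log_categorical(x, p, reduction='sum', dim=None):
    log_p = x * torch.log(torch.clamp(p, EPS, 1. - EPS))
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    return torch.sum(log_p, dim)

x = torch.tensor([[0., 1., 0.]])     # one-hot target
p = torch.tensor([[0.2, 0.7, 0.1]])  # predicted class probabilities
print(log_categorical(x, p, dim=1))  # tensor([-0.3567]) = log(0.7)
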
19 changes: 1 addition & 18 deletions train.py
@@ -18,31 +18,14 @@ def training(logger, save_path, max_patience, num_epochs, model, optimizer, trai
if model.dequantization:
batch = batch + torch.rand(batch.shape)
batch = batch.to(device)


# batch[0] -> numerical data
# batch[1] -> categorical data

#numerical = batch[0].float()
#categorical = batch[1]

# concatenate into big input
#batch = torch.cat((numerical, categorical), dim=1)

# Should be different model for each kind
#batch = torch.stack(batch[1]).float() # TODO: To access only one (categorical )attribute - only needed as long as no multi-head
# batch = batch[0]
# model returns the loss in forward
_, loss, _ = model.forward(batch)
# TODO: this was also implemented in utils.evaluation and utils.samples_generated
loss = loss['loss']
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()

# Validation
#loss_val = evaluation(val_loader, var_info, model=model, model_best=model, epoch=e,natural=natural,device=device)
loss_val, performance_df = evaluation(model=model, data_loader=val_loader, device=device)
loss_val = evaluation(model=model, data_loader=val_loader, device=device)
logger.write_to_board(name="Validation", scalars={"NLL": loss_val}, index=e)
print(f'Epoch: {e}, loss val={loss_val}')
nll_val.append(loss_val.detach()) # save for plotting
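
The max_patience argument in the signature suggests early stopping on the validation NLL; the logic itself is outside this hunk, so the following is only a sketch of the usual pattern with made-up values:

# hypothetical validation losses per epoch
val_nlls = [3.2, 2.9, 2.8, 2.85, 2.83, 2.9, 2.95]
max_patience = 2
best_nll, patience = float('inf'), 0
for e, loss_val in enumerate(val_nlls):
    if loss_val < best_nll:
        best_nll, patience = loss_val, 0  # improvement: checkpoint the model here
    else:
        patience += 1
        if patience > max_patience:
            print(f'early stopping at epoch {e}')
            break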
