CUDA 🥱
TECH-yufu committed Nov 29, 2022
1 parent 9d9a288 commit 73d0d39
Showing 5 changed files with 44 additions and 22 deletions.
8 changes: 4 additions & 4 deletions dataloaders.py
@@ -7,7 +7,7 @@
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import LabelBinarizer

-def load_dataset(dataset_name, batch_size, shuffle, seed):
+def load_dataset(dataset_name, batch_size, shuffle, seed, pin_memory):


     # TODO: sep=';'??
@@ -44,9 +44,9 @@ def load_dataset(dataset_name, batch_size, shuffle, seed):
     test_data = iterate_data(test_data)


-    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True)
-    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, drop_last=True)
-    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True)
+    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=pin_memory)
+    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=pin_memory)
+    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=pin_memory)

     return ((var_info, var_dtype), (train_loader, val_loader, test_loader))
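Why pin_memory matters (a minimal sketch, not part of this commit): pinning page-locks the host buffers the DataLoader copies batches into, which is what allows an asynchronous host-to-device transfer. The pattern below is hypothetical and only illustrates the pairing with non_blocking=True, which the commit itself never passes:

import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the train_data built inside load_dataset().
data = TensorDataset(torch.randn(1024, 8))
loader = DataLoader(data, batch_size=32, drop_last=True,
                    pin_memory=(device.type == "cuda"))

for (batch,) in loader:
    # non_blocking=True only overlaps the copy with compute when the
    # source tensor sits in pinned host memory.
    batch = batch.to(device, non_blocking=True)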
20 changes: 14 additions & 6 deletions main.py
@@ -15,12 +15,12 @@

 # general
 parser.add_argument('--seed', type=int, default=42)
-parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'])
+parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])

 # model parameters
 parser.add_argument('--lr', help='Starting learning rate', default=3e-4, type=float)
 parser.add_argument('--batch_size', help='"Batch size"', default=32, type=int)
-parser.add_argument('--natural', type=str, default='True', choices=['False', 'True'])
+parser.add_argument('--natural', dest='natural', action='store_true')

 parser.add_argument('--max_epochs', help='"Number of epochs to train for"', default=500, type=int)
 parser.add_argument('--max_patience', help='"If training does not improve for longer than --max_patience epochs, it is stopped"', default=4, type=int)
@@ -42,13 +42,21 @@
     os.mkdir(result_dir)
 name = 'vae'

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("Running {}".format(device))
+
+if device == 'cpu':
+    pin_memory = False
+else:
+    pin_memory = True
+
 torch.manual_seed(args.seed)
 np.random.seed(args.seed)


 # Loading dataset
 # information about variables and dataset loaders
-output = load_dataset(dataset_name=args.dataset, batch_size=args.batch_size, shuffle=True, seed=args.seed)
+output = load_dataset(dataset_name=args.dataset, batch_size=args.batch_size, shuffle=True, seed=args.seed, pin_memory=pin_memory)
 # extracting output
 info, loaders = output
 (var_info, var_dtype) = info
@@ -81,15 +89,15 @@


 prior = torch.distributions.MultivariateNormal(torch.zeros(L), torch.eye(L))
-model = VAE(total_num_vals=total_num_vals, L=L, var_info=var_info, D=D, M=M, natural=args.natural)
-
+model = VAE(total_num_vals=total_num_vals, L=L, var_info=var_info, D=D, M=M, natural=args.natural, device=device)
+model = model.to(device)
 # OPTIMIZER
 optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=args.lr)

 # Training procedure
 nll_val = training(name=logger.dir, max_patience=args.max_patience, num_epochs=args.max_epochs, model=model,
                    optimizer=optimizer,
-                   train_loader=train_loader, val_loader=val_loader, var_info=var_info, natural=args.natural)
+                   train_loader=train_loader, val_loader=val_loader, var_info=var_info, natural=args.natural, device=device)

 print(nll_val)
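Two side effects of this hunk worth flagging. First, device is now derived purely from torch.cuda.is_available(), so the --device flag is effectively ignored (and the device == 'cpu' check only works because torch.device supports equality comparison against a plain string). Second, switching --natural to action='store_true' changes its default from the always-truthy string 'True' to the boolean False. A sketch that honors the flag instead (an assumption, not what the commit does):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
parser.add_argument('--natural', action='store_true')  # default False unless the flag is passed
args = parser.parse_args()

# Respect the requested device, but fall back to CPU when CUDA is absent.
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
pin_memory = device.type == 'cuda'  # pinning only helps host-to-GPU copies
print("Running {}".format(device))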
28 changes: 19 additions & 9 deletions models.py
@@ -86,13 +86,14 @@ def forward(self, x, type='log_prob'):


 class Decoder(nn.Module):
-    def __init__(self, decoder_net, var_info, total_num_vals=None, natural=True):
+    def __init__(self, decoder_net, var_info, total_num_vals=None, natural=True, device=None):
         super(Decoder, self).__init__()

         self.decoder = decoder_net
         self.var_info = var_info
         self.total_num_vals = total_num_vals  # depends on num_classes of each attribute
         self.natural = natural
+        self.device = device
         # self.distribution = 'gaussian'

     def decode(self, z):
@@ -184,6 +185,7 @@ def sample(self, z):
     def log_prob(self, x, z):
         # calculating the log-probability which is later used for ELBO
         prob_d = self.decode(z)  # probability output
+        prob_d = prob_d.to(self.device)
         log_p = torch.zeros((len(prob_d), len(self.var_info)))
         prob_d_idx = 0
         for x_idx, var in enumerate(self.var_info):
@@ -202,17 +204,18 @@ def log_prob(self, x, z):
                 num_vals = self.var_info[var]['num_vals']

                 if self.natural:
-                    natural = to_natural(prob_d[:, prob_d_idx:prob_d_idx+num_vals])
-                    log_var = torch.log(torch.var(natural, dim=0))
+                    natural_param = to_natural(prob_d[:, prob_d_idx:prob_d_idx+num_vals])
+                    log_var = natural_param[:, 1]
                     # log_var = torch.log(prob_d)
-                    log_p[:, var] = log_normal_diag(x[:, x_idx:x_idx + 1], natural,
+                    log_p[:, var] = log_normal_diag(x[:, x_idx:x_idx + 1], natural_param[:, 0],
                                                     log_var, reduction='sum', dim=-1).sum(-1)
                     prob_d_idx += num_vals
                 else:
                     # don't know if reduction is correct
-                    log_var = torch.log(torch.var(prob_d[:, prob_d_idx:prob_d_idx+num_vals], dim=0))
+                    log_var = torch.log(prob_d[:, prob_d_idx:prob_d_idx+num_vals][:, 1])
+                    mu = prob_d[:, prob_d_idx:prob_d_idx+num_vals][:, 0]
                     # log_var = torch.log(prob_d)
-                    log_p[:, var] = log_normal_diag(x[:, x_idx:x_idx+1], prob_d[:, prob_d_idx:prob_d_idx+num_vals], log_var, reduction='sum', dim=-1).sum(-1)
+                    log_p[:, var] = log_normal_diag(x[:, x_idx:x_idx+1], mu, log_var, reduction='sum', dim=-1).sum(-1)
                     prob_d_idx += num_vals

             elif self.var_info[var]['dtype'] == 'bernoulli':
@@ -221,7 +224,7 @@ def log_prob(self, x, z):
             else:
                 raise ValueError('Either `gaussian`, `categorical`, or `bernoulli`')

-        return log_p.sum(axis=1)  # summing all log_probs
+        return log_p.sum(axis=1).to(self.device)  # summing all log_probs

     def forward(self, z, x=None, type='log_prob'):
         assert type in ['decoder', 'log_prob'], 'Type could be either decode or log_prob'
@@ -245,7 +248,7 @@ def log_prob(self, z):


 class VAE(nn.Module):
-    def __init__(self, total_num_vals, L, var_info, D, M, natural):
+    def __init__(self, total_num_vals, L, var_info, D, M, natural, device):
         super().__init__()

         encoder_net = nn.Sequential(nn.Linear(D, M), nn.LeakyReLU(),
@@ -255,10 +258,14 @@ def __init__(self, total_num_vals, L, var_info, D, M, natural):
         decoder_net = nn.Sequential(nn.Linear(L, M), nn.LeakyReLU(),
                                     nn.Linear(M, M), nn.LeakyReLU(),
                                     nn.Linear(M, total_num_vals))

+        encoder_net.to(device)
+        decoder_net.to(device)
+
         self.encoder = Encoder(encoder_net=encoder_net)

         # TODO: num_vals should be changed according to the num_classes in said feature --> i.e. multiple encoder/decoders per attribute (multi-head)
-        self.decoder = Decoder(var_info=var_info, decoder_net=decoder_net, total_num_vals=total_num_vals, natural=natural)
+        self.decoder = Decoder(var_info=var_info, decoder_net=decoder_net, total_num_vals=total_num_vals, natural=natural, device=device)

         # self.heads = nn.ModuleList([
         #     HIVAEHead(dist, hparams.size_s, hparams.size_z, hparams.size_y) for dist in prob_model
@@ -270,10 +277,13 @@ def __init__(self, total_num_vals, L, var_info, D, M, natural):


         self.var_info = var_info  # contains type of likelihood for variables

+        self.device = device
+
     def forward(self, x, reduction='avg'):
         # encoder
         mu_e, log_var_e = self.encoder.encode(x)
         z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)
+        z = z.to(self.device)

         # x_params = [head(y_shared, s_samples) for head in self.heads]
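The substantive change in log_prob is how the Gaussian parameters are read: the old code estimated log_var as the empirical variance across the batch, while the new code treats the decoder's two outputs per Gaussian variable as mean (column 0) and log-variance (column 1). A standalone sketch of that diagonal-Gaussian density under this two-column assumption (log_normal_diag itself lives elsewhere in the repo; this is not its actual implementation):

import math
import torch

def log_normal_diag_sketch(x, mu, log_var):
    # Elementwise log N(x | mu, diag(exp(log_var))).
    return -0.5 * (math.log(2 * math.pi) + log_var + (x - mu) ** 2 / torch.exp(log_var))

params = torch.randn(32, 2)                # decoder slice for one Gaussian variable
mu, log_var = params[:, 0], params[:, 1]   # column 0: mean, column 1: log-variance
x = torch.randn(32)
log_p = log_normal_diag_sketch(x, mu, log_var)  # shape (32,)

A smaller point on the device handling: allocating the accumulator with torch.zeros(..., device=prob_d.device) up front would keep every assignment on one device and make the trailing .to(self.device) on the return value unnecessary.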
6 changes: 4 additions & 2 deletions train.py
@@ -2,7 +2,7 @@
 import numpy as np
 from utils import evaluation, samples_generated

-def training(name, max_patience, num_epochs, model, optimizer, train_loader, val_loader, var_info, natural):
+def training(name, max_patience, num_epochs, model, optimizer, train_loader, val_loader, var_info, natural, device):
     nll_val = []
     best_nll = 1000.
     patience = 0
@@ -17,6 +17,8 @@ def training(name, max_patience, num_epochs, model, optimizer, train_loader, val_loader, var_info, natural):
             if hasattr(model, 'dequantization'):
                 if model.dequantization:
                     batch = batch + torch.rand(batch.shape)
+            batch = batch.to(device)
+

             # batch[0] -> numerical data
             # batch[1] -> categorical data
@@ -39,7 +41,7 @@ def training(name, max_patience, num_epochs, model, optimizer, train_loader, val_loader, var_info, natural):
             optimizer.step()

         # Validation
-        loss_val = evaluation(val_loader, var_info, model_best=model, epoch=e, natural=natural)
+        loss_val = evaluation(val_loader, var_info, model_best=model, epoch=e, natural=natural, device=device)
         nll_val.append(loss_val)  # save for plotting

         if e == 0:
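For reference, the per-batch transfer here pairs with the pinned loaders from dataloaders.py; passing non_blocking=True (which this commit does not) lets the copy overlap with compute. A sketch under the assumption that calling the model returns the training loss:

import torch

def train_epoch(model, loader, optimizer, device):
    model.train()
    for batch in loader:
        # Asynchronous copy: effective only because the DataLoader
        # was built with pin_memory=True.
        batch = batch.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = model(batch)   # assumption: forward returns the negative ELBO
        loss.backward()
        optimizer.step()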
4 changes: 3 additions & 1 deletion utils.py
@@ -3,7 +3,7 @@
 import torch
 from models import VAE

-def evaluation(test_loader, var_info, name=None, model_best=None, epoch=None, M=256, natural=False):
+def evaluation(test_loader, var_info, name=None, model_best=None, epoch=None, M=256, natural=False, device=None):
     # EVALUATION
     if model_best is None:
         D = len(var_info.keys())
@@ -12,13 +12,15 @@ def evaluation(test_loader, var_info, name=None, model_best=None, epoch=None, M=256, natural=False):
         for var in var_info.keys():
             total_num_vals += var_info[var]['num_vals']
         model_best = VAE(total_num_vals=total_num_vals, L=L, var_info=var_info, D=D, M=M, natural=natural)
+        model_best.to(device)
         # load best performing model
         model_best.load_state_dict(torch.load(name+'.model'))

     model_best.eval()
     loss = 0.
     N = 0.
     for indx_batch, test_batch in enumerate(test_loader):
+        test_batch = test_batch.to(device)
         # test_batch = torch.stack(test_batch[1]).float()  # TODO: To access only one attribute - only needed as long as no multi-head
         # TODO: adjust to normal batch
         # numerical = test_batch[0].float()
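One inconsistency this commit introduces: VAE.__init__ now takes a required device argument, but the fallback branch above still constructs VAE(...) without it, so the model_best is None path would raise a TypeError. Separately, an evaluation loop normally runs under torch.no_grad(); a sketch with hypothetical names:

import torch

def evaluate_sketch(model, loader, device):
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():                  # no gradient bookkeeping at eval time
        for batch in loader:
            batch = batch.to(device, non_blocking=True)
            total += model(batch).item()   # assumption: forward returns the summed batch loss
            n += batch.shape[0]
    return total / n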
