Skip to content

Commit

Permalink
Commenting out old version of the autoencoder
Browse files Browse the repository at this point in the history
  • Loading branch information
silkemaes committed Oct 10, 2024
1 parent 980762a commit 53d37a9
Showing 1 changed file with 81 additions and 79 deletions.
160 changes: 81 additions & 79 deletions src/mace/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torchode as to # Lienen, M., & Günnemann, S. 2022, in The Symbiosis of Deep Learning and Differential Equations II, NeurIPS. https://openreview.net/forum?id=uiKVKTiUYB0
import src.mace.autoencoder as ae
import src.mace.latentODE as lODE
from scipy.stats import gmean
from time import time


Expand Down Expand Up @@ -199,87 +198,90 @@ def forward(self, n_0, p, tstep):
## ---------- OLD VERSION OF THE SOLVER CLASS ---------- ##
## This class is compatible with an older version of the autoencoder

## NOTE(review): this span is a scraped GitHub diff view, not runnable source.
## The uncommented lines below are the OLD `Solver_old` implementation that this
## commit removes; the '#'-prefixed lines are the same code re-added in
## commented-out form. Leading indentation was lost in the page capture --
## confirm exact layout against the repository before reusing any of this code.
class Solver_old(nn.Module):
'''
The Solver class presents the architecture of MACE.
Components:
1) Encoder; neural network with adjustable amount of nodes and layers
2) Neural ODE; ODE given by function g, with trainable elements
3) Decoder; neural network with adjustable amount of nodes and layers
'''
## p_dim: number of physical parameters; z_dim: latent-space size;
## n_dim=466: size of the abundance vector (presumably the number of chemical
## species -- TODO confirm); g_nn selects the NN-parameterised ODE function;
## atol/rtol: solver tolerances passed to the torchode step-size controller.
def __init__(self, p_dim, z_dim, DEVICE, n_dim=466, g_nn = False, atol = 1e-5, rtol = 1e-2):
super(Solver_old, self).__init__() # type: ignore

## Per-phase lists of solver status codes, appended to by set_status().
self.status_train = list()
self.status_test = list()

self.z_dim = z_dim
self.n_dim = n_dim
self.DEVICE = DEVICE
self.g_nn = g_nn

## Setting the neural ODE
## With the plain G the physical parameters p are concatenated onto the
## encoder input; with Gnn they are passed as args to the ODE term instead.
input_ae_dim = n_dim
if not self.g_nn:
self.g = lODE.G(z_dim)
input_ae_dim = input_ae_dim+p_dim
self.odeterm = to.ODETerm(self.g, with_args=False)
if self.g_nn:
self.g = lODE.Gnn(p_dim, z_dim)
self.odeterm = to.ODETerm(self.g, with_args=True)

self.step_method = to.Dopri5(term=self.odeterm)
self.step_size_controller = to.IntegralController(atol=atol, rtol=rtol, term=self.odeterm)
self.adjoint = to.AutoDiffAdjoint(self.step_method, self.step_size_controller).to(self.DEVICE) # type: ignore

## Compile the adjoint solver once; reused for every forward() call.
self.jit_solver = torch.compile(self.adjoint)

## Setting the autoencoder (encoder + decoder)
## Hidden width is the geometric mean of the input and latent dimensions.
hidden_ae_dim = int(gmean([input_ae_dim, z_dim]))
self.encoder = ae.Encoder_old(input_dim=input_ae_dim, hidden_dim=hidden_ae_dim, latent_dim=z_dim)
self.decoder = ae.Decoder_old(latent_dim=z_dim , hidden_dim=hidden_ae_dim, output_dim=n_dim)

## Record one solver status code for the given phase ('train' or 'test').
## Any other phase value is silently ignored.
def set_status(self, status, phase):
if phase == 'train':
self.status_train.append(status)
elif phase == 'test':
self.status_test.append(status)

## Return the recorded status codes for a phase as a numpy array.
## NOTE(review): returns None for an unrecognised phase -- confirm callers
## only ever pass 'train' or 'test'.
def get_status(self, phase):
if phase == 'train':
return np.array(self.status_train)
elif phase == 'test':
return np.array(self.status_test)


## n_0: initial abundances; p: physical parameters; tstep: evaluation times.
## Returns (n_s, z_s, solution.status) -- decoded abundances per time step,
## the latent trajectory, and the torchode solver status.
def forward(self, n_0, p, tstep):
'''
Forward function giving the workflow of the MACE architecture.
'''

x_0 = n_0 ## use NN version of G
if not self.g_nn: ## DON'T use NN version of G
## Ravel the abundances n_0 and physical input p to x_0
x_0 = torch.cat((p, n_0), axis=-1) # type: ignore

## Encode x_0, returning the encoded z_0 in latent space
z_0 = self.encoder(x_0)
# from scipy.stats import gmean

# class Solver_old(nn.Module):
# '''
# The Solver class presents the architecture of MACE.
# Components:
# 1) Encoder; neural network with adjustable amount of nodes and layers
# 2) Neural ODE; ODE given by function g, with trainable elements
# 3) Decoder; neural network with adjustable amount of nodes and layers

# '''
# def __init__(self, p_dim, z_dim, DEVICE, n_dim=466, g_nn = False, atol = 1e-5, rtol = 1e-2):
# super(Solver_old, self).__init__() # type: ignore

# self.status_train = list()
# self.status_test = list()

# self.z_dim = z_dim
# self.n_dim = n_dim
# self.DEVICE = DEVICE
# self.g_nn = g_nn

# ## Setting the neural ODE
# input_ae_dim = n_dim
# if not self.g_nn:
# self.g = lODE.G(z_dim)
# input_ae_dim = input_ae_dim+p_dim
# self.odeterm = to.ODETerm(self.g, with_args=False)
# if self.g_nn:
# self.g = lODE.Gnn(p_dim, z_dim)
# self.odeterm = to.ODETerm(self.g, with_args=True)

# self.step_method = to.Dopri5(term=self.odeterm)
# self.step_size_controller = to.IntegralController(atol=atol, rtol=rtol, term=self.odeterm)
# self.adjoint = to.AutoDiffAdjoint(self.step_method, self.step_size_controller).to(self.DEVICE) # type: ignore

# self.jit_solver = torch.compile(self.adjoint)

# ## Setting the autoencoder (encoder + decoder)
# hidden_ae_dim = int(gmean([input_ae_dim, z_dim]))
# self.encoder = ae.Encoder_old(input_dim=input_ae_dim, hidden_dim=hidden_ae_dim, latent_dim=z_dim)
# self.decoder = ae.Decoder_old(latent_dim=z_dim , hidden_dim=hidden_ae_dim, output_dim=n_dim)

# def set_status(self, status, phase):
# if phase == 'train':
# self.status_train.append(status)
# elif phase == 'test':
# self.status_test.append(status)

# def get_status(self, phase):
# if phase == 'train':
# return np.array(self.status_train)
# elif phase == 'test':
# return np.array(self.status_test)


# def forward(self, n_0, p, tstep):
# '''
# Forward function giving the workflow of the MACE architecture.
# '''

# x_0 = n_0 ## use NN version of G
# if not self.g_nn: ## DON'T use NN version of G
# ## Ravel the abundances n_0 and physical input p to x_0
# x_0 = torch.cat((p, n_0), axis=-1) # type: ignore

# ## Encode x_0, returning the encoded z_0 in latent space
# z_0 = self.encoder(x_0)

## Create initial value problem
problem = to.InitialValueProblem(
y0 = z_0.to(self.DEVICE), ## "view" below is used to handle the batches
t_eval = tstep.view(z_0.shape[0],-1).to(self.DEVICE),
)
# ## Create initial value problem
# problem = to.InitialValueProblem(
# y0 = z_0.to(self.DEVICE), ## "view" below is used to handle the batches
# t_eval = tstep.view(z_0.shape[0],-1).to(self.DEVICE),
# )

## Solve initial value problem. Details are set in the __init__() of this class.
solution = self.jit_solver.solve(problem, args=p)
z_s = solution.ys.view(-1, self.z_dim) ## flatten because of the batches
# ## Solve initial value problem. Details are set in the __init__() of this class.
# solution = self.jit_solver.solve(problem, args=p)
# z_s = solution.ys.view(-1, self.z_dim) ## flatten because of the batches

## Decode the resulting values from latent space z_s back to physical space
n_s_ravel = self.decoder(z_s)
# ## Decode the resulting values from latent space z_s back to physical space
# n_s_ravel = self.decoder(z_s)

## Reshape correctly
## NOTE(review): the leading 1 hard-codes a batch size of one here -- confirm
## forward() is only ever called with a single sample per batch.
n_s = n_s_ravel.reshape(1,tstep.shape[-1], self.n_dim)
# ## Reshape correctly
# n_s = n_s_ravel.reshape(1,tstep.shape[-1], self.n_dim)

return n_s, z_s, solution.status
# return n_s, z_s, solution.status

0 comments on commit 53d37a9

Please sign in to comment.