Getting the updated examples from 'main' to 'workshop' #4

Closed · wants to merge 16 commits from main into workshop
Changes from 1 commit
update lds and vdp example to use smallL filter
matthew-dowling committed Jun 6, 2024
commit 389d4823302398c57c86f2b7ba67a452e4fcff72
4 changes: 2 additions & 2 deletions examples/lds_example/inference_lightning.py
@@ -13,7 +13,7 @@
 from xfads.smoothers.lightning_trainers import LightningNonlinearSSM
 from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
 from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
-from xfads.smoothers.nonlinear_smoother import NonlinearFilter, LowRankNonlinearStateSpaceModel
+from xfads.smoothers.nonlinear_smoother import NonlinearFilterSmallL, LowRankNonlinearStateSpaceModel
 
 
 def main():
@@ -66,7 +66,7 @@ def main():
                                                       device=cfg.device)
     local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
                                       device=cfg.device, dropout=cfg.p_local_dropout)
-    nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)
+    nl_filter = NonlinearFilterSmallL(dynamics_mod, initial_condition_pdf, device=cfg.device)
 
     """sequence vae"""
     ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
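The change in both example scripts is a one-line swap of the filter class; the constructor arguments are unchanged. A minimal sketch of the swap (dynamics_mod, initial_condition_pdf, and cfg are built in parts of the script the diff collapses, so they are assumed context here, not shown API):

from xfads.smoothers.nonlinear_smoother import NonlinearFilterSmallL

# Before this commit the examples built the standard filter:
#   nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)
# After, the small-L variant takes the same arguments:
nl_filter = NonlinearFilterSmallL(dynamics_mod, initial_condition_pdf, device=cfg.device)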
13 changes: 2 additions & 11 deletions examples/vdp_example/inference_lightning.py
@@ -1,9 +1,6 @@
 import math
 import torch
-import torch.nn as nn
-import xfads.utils as utils
 import pytorch_lightning as lightning
-import matplotlib.pyplot as plt
 
 from hydra import compose, initialize
 from pytorch_lightning.loggers import CSVLogger
@@ -13,13 +10,7 @@
 from xfads.smoothers.lightning_trainers import LightningNonlinearSSM
 from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
 from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
-# from dev.smoothers.nonlinear_smoother import NonlinearFilter, LowRankNonlinearStateSpaceModel
-from xfads.smoothers.nonlinear_smoother_causal import NonlinearFilter, LowRankNonlinearStateSpaceModel
-
-# from dev.smoothers.nonlinear_smoother_diagonal import NonlinearFilter
-# from dev.ssm_modules.encoders import LocalEncoderDiagonal as LocalEncoderLRMvn
-# from dev.ssm_modules.encoders import BackwardEncoderDiagonal as BackwardEncoderLRMvn
-# from dev.smoothers.nonlinear_smoother_diagonal import DiagonalNonlinearStateSpaceModel as LowRankNonlinearStateSpaceModel
+from xfads.smoothers.nonlinear_smoother import NonlinearFilterSmallL, LowRankNonlinearStateSpaceModel
 
 
 
@@ -72,7 +63,7 @@ def main():
                                                       device=cfg.device)
     local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
                                       device=cfg.device, dropout=cfg.p_local_dropout)
-    nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)
+    nl_filter = NonlinearFilterSmallL(dynamics_mod, initial_condition_pdf, device=cfg.device)
 
     """sequence vae"""
     ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
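Assembled from the two import hunks above, the vdp example's import block reads as follows after this commit (a reconstruction; any lines the diff collapses between the two hunks are omitted):

import math
import torch
import pytorch_lightning as lightning

from hydra import compose, initialize
from pytorch_lightning.loggers import CSVLogger
from xfads.smoothers.lightning_trainers import LightningNonlinearSSM
from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
from xfads.smoothers.nonlinear_smoother import NonlinearFilterSmallL, LowRankNonlinearStateSpaceModel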
6 changes: 4 additions & 2 deletions xfads/smoothers/nonlinear_smoother.py
@@ -356,6 +356,8 @@ def forward(self,
                 k: torch.Tensor,
                 K: torch.Tensor,
                 n_samples: int,
+                get_v: bool=False,
+                get_kl: bool=False,
                 p_mask: float=0.0):
 
         # mask data, 0: data available, 1: data missing
@@ -372,8 +374,8 @@
 
         for t in range(n_time_bins):
             if t == 0:
-                m_0 = self.initial_c_pdf.m0
-                P_0_diag = Fn.softplus(self.initial_c_pdf.log_v0)
+                m_0 = self.initial_c_pdf.m_0
+                P_0_diag = Fn.softplus(self.initial_c_pdf.log_Q_0)
                 z_f_t, m_f_t, P_f_chol_t, P_p_chol_t = filter_step_0(m_0, k[:, 0], K[:, 0], P_0_diag, n_samples)
                 m_p.append(m_0 * torch.ones(n_trials, n_latents, device=k[:, 0].device))
             else:
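Two interface details for downstream code follow from this file's hunks: forward gains get_v and get_kl keyword flags that default to False, so existing callers keep working, and the initial-condition module is now read through m_0 and log_Q_0 rather than m0 and log_v0. A hedged sketch (initial_c_pdf stands in for a DenseGaussianInitialCondition-like module and is assumed, not quoted from the repository):

import torch.nn.functional as Fn  # nonlinear_smoother.py uses this alias

# Renamed attributes on the (assumed) initial-condition module:
m_0 = initial_c_pdf.m_0                         # prior mean, was .m0
P_0_diag = Fn.softplus(initial_c_pdf.log_Q_0)   # prior variance diagonal, was Fn.softplus(.log_v0)

# New forward flags; parameters before `k` are collapsed in the diff, hence `...`:
#   nl_filter(..., k, K, n_samples, get_v=True, get_kl=True)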