Getting the updated examples from 'main' to 'workshop' #4

Closed · wants to merge 16 commits into 'workshop' from 'main'
2 changes: 1 addition & 1 deletion README.md
@@ -3,7 +3,7 @@ Approximate inference targeted at variational Gaussian state-space models with d


### introduction
-The core of a large scale variational smoothing (LSVS) module is a LowRankNonlinearStateSpaceModel object. A LowRankNonlinearStateSpaceModel is used to perform inference in a state-space graphical model specified by,
+A LowRankNonlinearStateSpaceModel object is used to perform inference in a state-space graphical model specified by,

$$p(y_{1:T}, z_{1:T}) = p_{\theta}(z_1)\, p_{\psi}(y_1 | z_1) \prod_{t=2}^{T} p_{\psi}(y_t | z_t)\, p_{\theta}(z_t | z_{t-1})$$
where …
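For orientation, here is a minimal sketch of how such a model is assembled. The constructor calls are copied from examples/lds_example/inference_lightning.py in this PR; the hyperparameter values (latent/hidden sizes, ranks, dropout, neuron count) are illustrative only, not recommendations:

```python
# Minimal assembly sketch, adapted from examples/lds_example/inference_lightning.py
# in this PR; sizes below are illustrative.
import torch
import torch.nn as nn
import xfads.utils as utils
from xfads.ssm_modules.likelihoods import PoissonLikelihood
from xfads.ssm_modules.dynamics import DenseGaussianDynamics, DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
from xfads.smoothers.nonlinear_smoother import NonlinearFilterSmallL, LowRankNonlinearStateSpaceModel

n_latents, n_neurons, bin_sz = 2, 100, 20e-3   # illustrative sizes

# p_theta(z_t | z_{t-1}): GRU mean function with diagonal Gaussian state noise
Q_diag = torch.ones(n_latents)
dynamics_fn = utils.build_gru_dynamics_function(n_latents, 64, device='cpu')
dynamics_mod = DenseGaussianDynamics(dynamics_fn, n_latents, Q_diag, device='cpu')

# p_theta(z_1): dense Gaussian initial condition
initial_condition_pdf = DenseGaussianInitialCondition(n_latents, torch.zeros(n_latents), Q_diag, device='cpu')

# p_psi(y_t | z_t): Poisson likelihood through a (masked) linear readout
readout_fn = nn.Sequential(utils.ReadoutLatentMask(n_latents, n_latents), nn.Linear(n_latents, n_neurons))
likelihood_pdf = PoissonLikelihood(readout_fn, n_neurons, bin_sz, device='cpu')

# low-rank amortized inference networks plus the nonlinear filter
backward_encoder = BackwardEncoderLRMvn(n_latents, 64, n_latents, rank_local=2, rank_backward=2, device='cpu')
local_encoder = LocalEncoderLRMvn(n_latents, n_neurons, 128, n_latents, rank=2, device='cpu', dropout=0.1)
nl_filter = NonlinearFilterSmallL(dynamics_mod, initial_condition_pdf, device='cpu')

ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf,
                                      backward_encoder, local_encoder, nl_filter, device='cpu')
```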
4 changes: 2 additions & 2 deletions examples/lds_example/inference_lightning.py
@@ -13,7 +13,7 @@
from xfads.smoothers.lightning_trainers import LightningNonlinearSSM
from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
-from xfads.smoothers.nonlinear_smoother import NonlinearFilter, LowRankNonlinearStateSpaceModel
+from xfads.smoothers.nonlinear_smoother import NonlinearFilterSmallL, LowRankNonlinearStateSpaceModel


def main():
@@ -66,7 +66,7 @@ def main():
                                            device=cfg.device)
    local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
                                      device=cfg.device, dropout=cfg.p_local_dropout)
-    nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)
+    nl_filter = NonlinearFilterSmallL(dynamics_mod, initial_condition_pdf, device=cfg.device)

"""sequence vae"""
ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
Expand Down
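The only functional change in this example is the swap of NonlinearFilter for NonlinearFilterSmallL. A hedged sketch of keeping both construction paths switchable — both class names come from this diff, but the reading of "SmallL" as a variant specialized for small latent dimension, and the cutoff value, are assumptions:

```python
# Hypothetical helper for switching filter implementations; both classes are
# imported from the module shown in this diff. The latent-size criterion and
# small_l_cutoff are illustrative assumptions, not part of the PR.
from xfads.smoothers.nonlinear_smoother import NonlinearFilter, NonlinearFilterSmallL

def make_filter(dynamics_mod, initial_condition_pdf, n_latents, device, small_l_cutoff=8):
    cls = NonlinearFilterSmallL if n_latents <= small_l_cutoff else NonlinearFilter
    return cls(dynamics_mod, initial_condition_pdf, device=device)
```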
29 changes: 13 additions & 16 deletions examples/monkey_reaching/config.yaml
@@ -1,21 +1,20 @@
# --- graphical model --- #
-n_latents: 2
-n_latents_read: 2
+n_latents: 40
+n_latents_read: 35

-rank_local: 2
-rank_backward: 2
+rank_local: 15
+rank_backward: 5

Q_init: 1e-1
-n_hidden_dynamics: 64
+n_hidden_dynamics: 128

# --- inference network --- #
-n_samples: 5
-n_hidden_local: 128
-n_hidden_backward: 64
+n_samples: 25
+n_hidden_local: 256
+n_hidden_backward: 128

# --- hyperparameters --- #
use_cd: False
-p_mask_a: 0.5
+p_mask_a: 0.0
p_mask_b: 0.0
p_mask_apb: 0.0
p_mask_y_in: 0.0
@@ -27,17 +26,15 @@ device: 'cpu'
data_device: 'cpu'

lr: 1e-3
-n_epochs: 1500
-batch_sz: 128
+n_epochs: 1000
+batch_sz: 512

# --- misc --- #
bin_sz: 20e-3
bin_sz_ms: 20

-seed: 1234
+seed: 1236
default_dtype: torch.float32

# --- ray --- #
-n_ray_samples: 10
-
-
+n_ray_samples: 10
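These values are consumed through Hydra by the example scripts; the snippet below shows the loading pattern taken from inference_smoother_acausal.py later in this diff, with the printed values reflecting the updated config:

```python
# Load examples/monkey_reaching/config.yaml with Hydra, using the same
# initialize()/compose() pattern as inference_smoother_acausal.py below.
from hydra import compose, initialize

initialize(version_base=None, config_path="", job_name="monkey_reaching")
cfg = compose(config_name="config")

print(cfg.n_latents, cfg.rank_local, cfg.batch_sz)  # 40 15 512 after this PR
```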
51 changes: 51 additions & 0 deletions examples/monkey_reaching/download_data.py
@@ -0,0 +1,51 @@
import h5py
import torch
import requests

from pathlib import Path


def main():
    Path("data").mkdir(parents=True, exist_ok=True)

    # use the '/raw/' URL so requests downloads the HDF5 file itself rather
    # than the HTML page that the '/blob/' URL returns
    url = 'https://github.com/arsedler9/lfads-torch/raw/main/datasets/mc_maze-20ms-val.h5'
    r = requests.get(url, allow_redirects=True)
    r.raise_for_status()
    open('data/mc_maze_20ms.h5', 'wb').write(r.content)

    bin_sz = 20
    save_root_path = 'data/'
    data_path = 'data/mc_maze_20ms.h5'

    with h5py.File(data_path, "r") as h5file:
        data_dict = {k: v[()] for k, v in h5file.items()}

    train_data, valid_data, test_data = {}, {}, {}
    seq_len = data_dict['train_encod_data'].shape[1]
    n_valid_trials = data_dict['valid_recon_data'].shape[0]

    train_data['y_obs'] = torch.Tensor(data_dict['train_recon_data'])
    train_data['velocity'] = torch.Tensor(data_dict['train_behavior'])
    train_data['n_neurons_enc'] = data_dict['train_encod_data'].shape[-1]
    train_data['n_neurons_obs'] = data_dict['train_recon_data'].shape[-1]
    train_data['n_time_bins_enc'] = seq_len

    # the validation trials are split in half: first half -> valid, second half -> test
    valid_data['y_obs'] = torch.Tensor(data_dict['valid_recon_data'])[:n_valid_trials//2]
    valid_data['velocity'] = torch.Tensor(data_dict['valid_behavior'])[:n_valid_trials//2]
    valid_data['n_neurons_enc'] = data_dict['valid_encod_data'].shape[-1]
    valid_data['n_neurons_obs'] = data_dict['valid_recon_data'].shape[-1]
    valid_data['n_time_bins_enc'] = seq_len

    test_data['y_obs'] = torch.Tensor(data_dict['valid_recon_data'])[n_valid_trials//2:]
    test_data['velocity'] = torch.Tensor(data_dict['valid_behavior'])[n_valid_trials//2:]
    test_data['n_neurons_enc'] = data_dict['valid_encod_data'].shape[-1]
    test_data['n_neurons_obs'] = data_dict['valid_recon_data'].shape[-1]
    test_data['n_time_bins_enc'] = seq_len

    torch.save(train_data, save_root_path + f'data_train_{bin_sz}ms.pt')
    torch.save(valid_data, save_root_path + f'data_valid_{bin_sz}ms.pt')
    torch.save(test_data, save_root_path + f'data_test_{bin_sz}ms.pt')
    print(f'train shape: {train_data["y_obs"].shape}')
    print(f'valid shape: {valid_data["y_obs"].shape}')


if __name__ == '__main__':
    main()
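After the script runs, the saved splits can be sanity-checked directly; this mirrors the torch.load pattern used by the training scripts below:

```python
# Quick inspection of the artifacts written by download_data.py; the filename
# template matches data_path = 'data/data_{split}_{bin_sz_ms}ms.pt' used below.
import torch

train_data = torch.load('data/data_train_20ms.pt')
print(train_data['y_obs'].shape)        # (n_trials, n_time_bins, n_neurons_obs)
print(train_data['n_time_bins_enc'])    # length of the encoding window, in bins
```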
146 changes: 49 additions & 97 deletions examples/monkey_reaching/inference_smoother_acausal.py
@@ -1,113 +1,65 @@
import os
os.environ["OMP_NUM_THREADS"] = "8" # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "8" # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "8" # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "8" # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "8" # export NUMEXPR_NUM_THREADS=6

import math
import torch
import torch.nn as nn
import xfads.utils as utils
import xfads.prob_utils as prob_utils
import pytorch_lightning as lightning
import matplotlib.pyplot as plt

from hydra import compose, initialize
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from xfads.ssm_modules.likelihoods import PoissonLikelihood
from xfads.ssm_modules.dynamics import DenseGaussianDynamics
from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
-from xfads.smoothers.lightning_trainers import LightningNonlinearSSM, LightningMonkeyReaching
-from xfads.smoothers.nonlinear_smoother import NonlinearFilter, LowRankNonlinearStateSpaceModel
-# from dev.smoothers.nonlinear_smoother_causal_debug import NonlinearFilter, LowRankNonlinearStateSpaceModel
+from xfads.smoothers.lightning_trainers import LightningMonkeyReaching
+from xfads.ssm_modules.prebuilt_models import create_xfads_poisson_log_link


def main():
+    # at t=n_bins_bhv start forecast
+    n_bins_bhv = 10

    torch.cuda.empty_cache()
    initialize(version_base=None, config_path="", job_name="monkey_reaching")
    cfg = compose(config_name="config")

-    n_bins_bhv = 10
-    seeds = [1234, 1235, 1236]

-    """config"""
-    for seed in seeds:
-        cfg.seed = seed

-        lightning.seed_everything(cfg.seed, workers=True)
-        torch.set_default_dtype(torch.float32)
-
-        """data"""
-        data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
-        train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
-        val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
-        test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))
-
-        y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
-        y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
-        y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)
-        vel_valid = val_data['velocity'].type(torch.float32).to(cfg.data_device)
-        vel_train = train_data['velocity'].type(torch.float32).to(cfg.data_device)
-        vel_test = test_data['velocity'].type(torch.float32).to(cfg.data_device)
-        n_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
-        n_time_bins_enc = train_data['n_time_bins_enc']
-
-        y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
-        y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
-        y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)
-        train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
-        valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
-        test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
-
-        """likelihood pdf"""
-        H = utils.ReadoutLatentMask(cfg.n_latents, cfg.n_latents_read)
-        readout_fn = nn.Sequential(H, nn.Linear(cfg.n_latents_read, n_neurons_obs))
-        readout_fn[-1].bias.data = prob_utils.estimate_poisson_rate_bias(train_dataloader, cfg.bin_sz)
-        likelihood_pdf = PoissonLikelihood(readout_fn, n_neurons_obs, cfg.bin_sz, device=cfg.device)
-
-        """dynamics module"""
-        Q_diag = 1. * torch.ones(cfg.n_latents, device=cfg.device)
-        dynamics_fn = utils.build_gru_dynamics_function(cfg.n_latents, cfg.n_hidden_dynamics, device=cfg.device)
-        dynamics_mod = DenseGaussianDynamics(dynamics_fn, cfg.n_latents, Q_diag, device=cfg.device)
-
-        """initial condition"""
-        m_0 = torch.zeros(cfg.n_latents, device=cfg.device)
-        Q_0_diag = 1. * torch.ones(cfg.n_latents, device=cfg.device)
-        initial_condition_pdf = DenseGaussianInitialCondition(cfg.n_latents, m_0, Q_0_diag, device=cfg.device)
-
-        """local/backward encoder"""
-        backward_encoder = BackwardEncoderLRMvn(cfg.n_latents, cfg.n_hidden_backward, cfg.n_latents,
-                                                rank_local=cfg.rank_local, rank_backward=cfg.rank_backward,
-                                                device=cfg.device)
-        local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons_obs, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
-                                          device=cfg.device, dropout=cfg.p_local_dropout)
-        nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)
-
-        """sequence vae"""
-        ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
-                                              local_encoder, nl_filter, device=cfg.device)
-
-        """lightning"""
-        seq_vae = LightningMonkeyReaching(ssm, cfg, n_time_bins_enc, n_bins_bhv)
-        csv_logger = CSVLogger('logs/smoother/acausal/', name=f'sd_{cfg.seed}_r_y_{cfg.rank_local}_r_b_{cfg.rank_backward}', version='smoother_acausal')
-        ckpt_callback = ModelCheckpoint(save_top_k=3, monitor='r2_valid_enc', mode='max', dirpath='ckpts/smoother/acausal/', save_last=True,
-                                        filename='{epoch:0}_{valid_loss:0.2f}_{r2_valid_enc:0.2f}_{r2_valid_bhv:0.2f}_{valid_bps_enc:0.2f}')
-
-        trainer = lightning.Trainer(max_epochs=cfg.n_epochs,
-                                    gradient_clip_val=1.0,
-                                    default_root_dir='lightning/',
-                                    callbacks=[ckpt_callback],
-                                    logger=csv_logger,
-                                    strategy='ddp',
-                                    accelerator='gpu',
-                                    )
-
-        trainer.fit(model=seq_vae, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
-        torch.save(ckpt_callback.best_model_path, 'ckpts/smoother/acausal/best_model_path.pt')
-        trainer.test(dataloaders=test_dataloader, ckpt_path='last')
+    lightning.seed_everything(cfg.seed, workers=True)
+    torch.set_default_dtype(torch.float32)
+
+    """data"""
+    data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'
+    train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
+    val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
+    test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))
+
+    y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
+    y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
+    y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)
+    vel_valid = val_data['velocity'].type(torch.float32).to(cfg.data_device)
+    vel_train = train_data['velocity'].type(torch.float32).to(cfg.data_device)
+    vel_test = test_data['velocity'].type(torch.float32).to(cfg.data_device)
+    n_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
+    n_time_bins_enc = train_data['n_time_bins_enc']
+
+    y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
+    y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
+    y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)
+    train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
+    valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
+    test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
+
+    """create ssm"""
+    ssm = create_xfads_poisson_log_link(cfg, n_neurons_obs, train_dataloader, model_type='n')
+
+    """lightning"""
+    seq_vae = LightningMonkeyReaching(ssm, cfg, n_time_bins_enc, n_bins_bhv)
+    csv_logger = CSVLogger('logs/smoother/acausal/', name=f'sd_{cfg.seed}_r_y_{cfg.rank_local}_r_b_{cfg.rank_backward}', version='smoother_acausal')
+    ckpt_callback = ModelCheckpoint(save_top_k=3, monitor='r2_valid_enc', mode='max', dirpath='ckpts/smoother/acausal/', save_last=True,
+                                    filename='{epoch:0}_{valid_loss:0.2f}_{r2_valid_enc:0.2f}_{r2_valid_bhv:0.2f}_{valid_bps_enc:0.2f}')
+
+    trainer = lightning.Trainer(max_epochs=cfg.n_epochs,
+                                gradient_clip_val=1.0,
+                                default_root_dir='lightning/',
+                                callbacks=[ckpt_callback],
+                                logger=csv_logger,
+                                )
+
+    trainer.fit(model=seq_vae, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
+    torch.save(ckpt_callback.best_model_path, 'ckpts/smoother/acausal/best_model_path.pt')
+    trainer.test(dataloaders=test_dataloader, ckpt_path='last')


if __name__ == '__main__':
    main()
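Once training finishes, the saved best-model path can be used to restore the smoother for analysis. A sketch under stated assumptions: the checkpoint layout (a 'state_dict' key) is standard PyTorch Lightning, and the model is rebuilt with the same constructor arguments as in main() above:

```python
# Hypothetical post-training restore; 'best_model_path.pt' holds just the path
# string written by torch.save(ckpt_callback.best_model_path, ...) above.
import torch
from xfads.smoothers.lightning_trainers import LightningMonkeyReaching

best_path = torch.load('ckpts/smoother/acausal/best_model_path.pt')
ckpt = torch.load(best_path, map_location='cpu')

# assumption: ssm, cfg, n_time_bins_enc, n_bins_bhv are rebuilt as in main()
seq_vae = LightningMonkeyReaching(ssm, cfg, n_time_bins_enc, n_bins_bhv)
seq_vae.load_state_dict(ckpt['state_dict'])
seq_vae.eval()
```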