Skip to content

Commit

Permalink
results: update GMM
Browse files Browse the repository at this point in the history
  • Loading branch information
laurence committed Aug 3, 2022
1 parent ecc9de6 commit 02c969b
Show file tree
Hide file tree
Showing 32 changed files with 59 additions and 792 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
*.idea/
*__pycache__/
/examples/paper_results/gmm/models/
/examples/paper_results/many_well/models/
/examples/many_well/models/
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ python examples/gmm/run.py training.use_buffer=True training.prioritised_buffer=
To run the full set of experiments see the [README](../examples/gmm/README.md) for the GMM experiments.

The below plot shows samples from various trained models, with the GMM problem target contours in the background.
![Gaussian Mixture Model samples vs contours](./examples/paper_results/gmm/plots/MoG.png)
![Gaussian Mixture Model samples vs contours](examples/gmm/plots/MoG.png)

### Many Well distribution
The 32 Many Well distribution is made up of 16 repeats of the Double Well distribution,
Expand All @@ -37,7 +37,7 @@ To run the full set of experiments see the [README](./examples/many_well/README.
The below plot shows samples for our model (FAB) vs training a flow by reverse KL divergence
minimisation, with the Many Well problem target contours in the background.
This visualisation is for the marginal pairs of the distributions for the first four elements of the x.
![Many Well distribution FAB vs training by KL divergence minimisation](./examples/paper_results/many_well/plots/many_well.png)
![Many Well distribution FAB vs training by KL divergence minimisation](examples/many_well/plots/many_well.png)

**Alanine Dipeptide distribution**

Expand Down
33 changes: 33 additions & 0 deletions examples/gmm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# GMM Problem
## Experiments
The following commands can each be used to train the methods from the paper:
```
# FAB with prioritised buffer.
python examples/gmm/run.py training.seed=0,1,2 training.use_buffer=True training.prioritised_buffer=True
# FAB without the prioritised buffer.
python examples/gmm/run.py training.seed=0,1,2 fab.loss_type=p2_over_q_alpha_2_div
# Flow using ground truth samples, training by maximum likelihood/forward KL divergence minimiation.
python examples/gmm/run.py training.seed=0,1,2 fab.loss_type=target_forward_kl
# Flow using alpha-divergence, with alpha=2
python examples/gmm/run.py training.seed=0,1,2 fab.loss_type=flow_alpha_2_div_nis
# Flow using reverse KL divergence
python examples/gmm/run.py training.seed=0,1,2 fab.loss_type=flow_reverse_kld
# SNF using reverse KLD
python examples/gmm/run.py training.seed=0,1,2 flow.use_snf=True
```

**Further notes** This will use hydra-multirun to run the random seeds in parallel.
However, if you just want to run locally and get a general idea of the results,
you can run a single random seed for a much lower number of iterations.
The config file for this experiment is [here](../config/gmm.yaml), where you can change the hyper-parameters.

## Evaluation
Trained models may be evaluated using the code in
[`evaluation.py`](evaluation.py) and [`evaluation_expectation_quadratic_func.py`](evaluation_expectation_quadratic_func.py).
Furthermore [`results_vis.py`](results_vis.py) may be used to obtain the plot from the paper
visualising each of the modes.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def evaluate(cfg: DictConfig, model_name: str, num_samples=int(1e4)):
return eval


@hydra.main(config_path="../../config", config_name="gmm.yaml")
@hydra.main(config_path="../config", config_name="gmm.yaml")
def main(cfg: DictConfig):
model_names = ["fab_buffer", "fab_no_buffer", "flow_kld", "flow_nis", "target_kld", "snf"]
seeds = [1, 2, 3]
Expand All @@ -89,7 +89,7 @@ def main(cfg: DictConfig):
results.to_csv(open(FILENAME_EVAL_INFO, "w"))


FILENAME_EVAL_INFO = "/home/laurence/work/code/FAB-TORCH/examples/paper_results/gmm/gmm_results.csv"
FILENAME_EVAL_INFO = "/examples/paper_results/gmm/gmm_results.csv"

if __name__ == '__main__':
main()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def evaluate(cfg: DictConfig, model_name: str, num_samples=int(1e3), n_repeats=1
return info


@hydra.main(config_path="../../config", config_name="gmm.yaml")
@hydra.main(config_path="../config", config_name="gmm.yaml")
def main(cfg: DictConfig):
model_names = ["target", "fab_buffer", "fab_no_buffer", "flow_kld", "flow_nis", "target_kld", "snf"]
seeds = [1, 2, 3]
Expand All @@ -118,7 +118,7 @@ def main(cfg: DictConfig):
print(results.groupby("model_name").std()[fields])
results.to_csv(open(FILENAME_EXPECTATION_INFO, "w"))

FILENAME_EXPECTATION_INFO = "/home/laurence/work/code/FAB-TORCH/examples/paper_results/gmm/gmm_results_expectation.csv"
FILENAME_EXPECTATION_INFO = "/examples/paper_results/gmm/gmm_results_expectation.csv"

if __name__ == '__main__':
if True:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd
from examples.paper_results.gmm.evaluation import FILENAME_EVAL_INFO
from examples.paper_results.gmm.evaluation_expectation_quadratic_func import FILENAME_EXPECTATION_INFO
from examples.gmm.evaluation import FILENAME_EVAL_INFO
from examples.gmm.evaluation_expectation_quadratic_func import FILENAME_EXPECTATION_INFO



Expand Down
File renamed without changes
Binary file added examples/gmm/plots/MoG_appendix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def plot_result(cfg: DictConfig, ax: plt.axes, target, model_name: Optional[str]
plot_marginal_pair(samples_flow, ax=ax, bounds=plotting_bounds, alpha=alpha)


@hydra.main(config_path="../../config", config_name="gmm.yaml")
@hydra.main(config_path="../config", config_name="gmm.yaml")
def run(cfg: DictConfig):
appendix = True
if appendix:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

PATH = os.getcwd()

@hydra.main(config_path="./", config_name="config.yaml")
@hydra.main(config_path="/", config_name="config.yaml")
def run(cfg: DictConfig):
dim = cfg.target.dim
target = ManyWellEnergy(cfg.target.dim, a=-0.5, b=-6, use_gpu=False)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from examples.setup_run_snf import make_normflow_snf_model, SNFModel

from fab.target_distributions.many_well import ManyWellEnergy
from examples.paper_results.many_well.old_target_many_well import ManyWellEnergy as OldManyWellEnergy
from examples.many_well.old_target_many_well import ManyWellEnergy as OldManyWellEnergy
import pandas as pd
import os
from omegaconf import DictConfig
import torch

from fab import FABModel, HamiltonianMonteCarlo, Metropolis
from fab import FABModel, HamiltonianMonteCarlo
from examples.make_flow import make_wrapped_normflowdist


Expand Down Expand Up @@ -87,7 +87,7 @@ def evaluate(cfg: DictConfig, model_name: str, target, num_samples=int(5e4)):
return eval


@hydra.main(config_path="../../config", config_name="many_well.yaml")
@hydra.main(config_path="../config", config_name="many_well.yaml")
def main(cfg: DictConfig):
model_names = ["target_kld"] # ["fab_buffer", "fab_no_buffer", "flow_kld", "flow_nis", "snf"]
seeds = [1, 2, 3]
Expand Down Expand Up @@ -121,7 +121,7 @@ def main(cfg: DictConfig):
print("overall results")
print(results[["model_name", "seed"] + keys])

FILENAME_EVAL_INFO = "/home/laurence/work/code/FAB-TORCH/examples/paper_results/many_well/many_well_results_ml_only.csv"
FILENAME_EVAL_INFO = "/examples/many_well/many_well_results_ml_only.csv"


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from examples.paper_results.many_well.evaluation import FILENAME_EVAL_INFO
from examples.many_well.evaluation import FILENAME_EVAL_INFO



Expand Down
2 changes: 1 addition & 1 deletion examples/many_well.py → examples/many_well/many_well.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _run(cfg: DictConfig):
setup_trainer_and_run_snf(cfg, setup_many_well_plotter, target)


@hydra.main(config_path="./config", config_name="many_well.yaml")
@hydra.main(config_path="../config", config_name="many_well.yaml")
def run(cfg: DictConfig):
_run(cfg)

Expand Down
File renamed without changes.
File renamed without changes
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from typing import Optional
import hydra
import matplotlib.pyplot as plt
from matplotlib import rc
import matplotlib as mpl
from omegaconf import DictConfig
from examples.make_flow import make_wrapped_normflowdist
from examples.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair
from examples.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
import torch
Expand Down Expand Up @@ -55,7 +54,7 @@ def plot_marginals(cfg: DictConfig, supfig, model_name, plot_y_label):



@hydra.main(config_path="./", config_name="config.yaml")
@hydra.main(config_path="/", config_name="config.yaml")
def run(cfg: DictConfig):
mpl.rcParams['figure.dpi'] = 300
rc('font', **{'family': 'serif', 'serif': ['Times']})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
from typing import Optional
import hydra
import matplotlib.pyplot as plt
from matplotlib import rc
import matplotlib as mpl
from omegaconf import DictConfig
from examples.make_flow import make_wrapped_normflowdist
from examples.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair
from examples.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
from examples.paper_results.many_well.old_target_many_well import ManyWellEnergy as OldManyWellEnergy
from examples.many_well.old_target_many_well import ManyWellEnergy as OldManyWellEnergy
import torch
from examples.setup_run_snf import make_normflow_snf_model, SNFModel

Expand Down Expand Up @@ -82,7 +81,7 @@ def plot_marginals(cfg: DictConfig, supfig, model_name, plot_y_label):



@hydra.main(config_path="./", config_name="config.yaml")
@hydra.main(config_path="/", config_name="config.yaml")
def run(cfg: DictConfig):
mpl.rcParams['figure.dpi'] = 300
rc('font', **{'family': 'serif', 'serif': ['Times']})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
from typing import Optional
import hydra
import matplotlib.pyplot as plt
from matplotlib import rc
import matplotlib as mpl
from omegaconf import DictConfig
from examples.make_flow import make_wrapped_normflowdist
from examples.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair
from examples.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
from examples.paper_results.many_well.old_target_many_well import ManyWellEnergy as OldManyWellEnergy
from examples.many_well.old_target_many_well import ManyWellEnergy as OldManyWellEnergy
import torch
from examples.setup_run_snf import make_normflow_snf_model, SNFModel

Expand Down Expand Up @@ -83,7 +82,7 @@ def plot_marginals(cfg: DictConfig, supfig, model_name, plot_y_label):



@hydra.main(config_path="./", config_name="config.yaml")
@hydra.main(config_path="/", config_name="config.yaml")
def run(cfg: DictConfig):
mpl.rcParams['figure.dpi'] = 300
rc('font', **{'family': 'serif', 'serif': ['Times']})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def evaluate(cfg: DictConfig, target, num_samples=int(100), log_prob_scale_facto
return eval


@hydra.main(config_path="../../config", config_name="many_well.yaml")
@hydra.main(config_path="../config", config_name="many_well.yaml")
def main(cfg: DictConfig):
model_names = ["snf"]
log_prob_scale_factors = [0, 1, 10, 10, 70]
Expand Down
37 changes: 0 additions & 37 deletions examples/paper_results/gmm/Makefile

This file was deleted.

9 changes: 0 additions & 9 deletions examples/paper_results/gmm/README.md

This file was deleted.

19 changes: 0 additions & 19 deletions examples/paper_results/gmm/gmm_results.csv

This file was deleted.

22 changes: 0 additions & 22 deletions examples/paper_results/gmm/gmm_results_expectation.csv

This file was deleted.

Loading

0 comments on commit 02c969b

Please sign in to comment.