Commit f472a38
[CLEAN] cleanup to have only rate and distortion
YannDubs committed Dec 5, 2020
1 parent 9f341da commit f472a38
Showing 32 changed files with 448 additions and 323 deletions.
24 changes: 11 additions & 13 deletions config/config.yaml
@@ -4,17 +4,16 @@ defaults:
 
   - data: toyMiniFashionMnist
   - encoder: cnn
-  - decoder: mlp
-  - loss: vae
-  - coder: MI_unitgaussian
+  - distortion: vae
+  - rate: MI_unitgaussian
 
 ### GENERAL ###
 name: ???
 seed: 123
 is_debug: False # enter debug mode
 time: ${now:%Y-%m-%d_%H-%M-%S}
 
-long_name: ${name}/d_${data.name}/l_${loss.name}/e_${encoder.name}/d_${decoder.name}/c_${coder.name}/o_${optimizer.name}/dim_${encoder.z_dim}/z_${loss.n_z_samples}/b_${loss.kwargs.beta}/s_${seed}/${time}
+long_name: ${name}/data_${data.name}/d_${distortion.name}/e_${encoder.name}/r_${rate.name}/o_${optimizer.name}/dim_${encoder.z_dim}/z_${loss.n_z_samples}/b_${loss.beta}/s_${seed}/${time}
 
 paths:
   base_dir: ???
@@ -28,8 +27,8 @@ optimizer: # might need to be a group at some point
   lr: 1e-3
   scheduling_factor: 100 # by how much to reduce lr during training
 
-# only used if coder needs an optimizer
-optimizer_coder:
+# only used if coder needs an optimizer (for coding)
+optimizer_coder :
   name: adam # not used yet but can change if needed
   lr: 1e-3
   scheduling_factor: 1 # by how much to reduce lr during training
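
A minimal sketch of how the optimizer blocks above might be consumed, assuming a plain factory over torch.optim (the factory itself is illustrative, not code from this repository; name, lr, and scheduling_factor are the fields from the diff):

    import torch

    def build_optimizer(params, cfg, num_epochs):
        # cfg mirrors the `optimizer` / `optimizer_coder` blocks above
        optimizers = {"adam": torch.optim.Adam}
        optimizer = optimizers[cfg.name](params, lr=float(cfg.lr))  # YAML may parse 1e-3 as a string
        # decay the lr by a total factor of `scheduling_factor` over training
        gamma = (1.0 / cfg.scheduling_factor) ** (1.0 / num_epochs)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
        return optimizer, scheduler

With scheduling_factor: 100 over, say, 100 epochs, gamma ≈ 0.955, so the learning rate ends training 100x smaller; scheduling_factor: 1 leaves it constant.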
@@ -110,22 +109,21 @@ encoder:
   fam: diaggaussian
   fam_kwargs: {}
 
-decoder:
+distortion:
   name: ???
-  arch: ???
-  out_shape: ${data.target_aux_shape}
+  mode: ???
+  factor_beta : 1 # factor that multiplies beta
   arch_kwargs:
     complexity: ???
 
-coder:
+rate:
   name: ???
   range_coder: null
+  factor_beta : 1 # factor that multiplies beta
   kwargs: {}
 
 loss:
   name: ???
   n_z_samples: 1 # number of samples to use inside log (like IWAE)
-  kwargs:
-    beta: 1
+  beta: 1
 
 ### HYDRA ###
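
After this refactor a run selects a distortion and a rate from the config groups added below. A hedged sketch of composing the new config with Hydra's compose API (the config directory and override names are taken from this diff; using compose here, rather than the project's actual entry point, is an assumption, and a recent Hydra, 1.1+, is assumed):

    from hydra import compose, initialize

    with initialize(config_path="config"):
        cfg = compose(
            config_name="config",
            overrides=["name=demo", "distortion=vib", "rate=H_factorized", "loss.beta=0.1"],
        )
    print(cfg.distortion.name, cfg.rate.name, cfg.loss.beta)  # vib H_factorized 0.1

Note that long_name now reads b_${loss.beta} rather than b_${loss.kwargs.beta}, matching the move of beta out of loss.kwargs.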
5 changes: 0 additions & 5 deletions config/decoder/cnn.yaml

This file was deleted.

5 changes: 0 additions & 5 deletions config/decoder/mlp.yaml

This file was deleted.

16 changes: 16 additions & 0 deletions config/distortion/gvae.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+distortion:
+  name: gvae
+  mode: direct
+  kwargs:
+    z_dim: ${encoder.z_dim}
+    y_shape: ${data.target_aux_shape}
+    is_classification: ${data.is_classification_aux}
+    arch: null
+    arch_kwargs:
+      complexity: 2
+
+data:
+  kwargs:
+    dataset_kwargs:
+      additional_target: "representative"
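
All of the new distortion files start with # @package _global_, so their keys merge at the root of the composed config rather than under a distortion group node; that is what lets a single file set both distortion.* and data.kwargs.dataset_kwargs.additional_target. For contrast, an illustrative sketch of the two packaging modes (values are placeholders drawn from this diff):

    # `# @package _group_` (as in config/encoder/mlp.yaml): keys land under the group
    encoder:
      name: mlp

    # `# @package _global_` (as in the distortion files): keys merge at the root,
    # so one file can configure several top-level sections at once
    distortion:
      name: gvae
    data:
      kwargs:
        dataset_kwargs:
          additional_target: "representative"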
16 changes: 16 additions & 0 deletions config/distortion/gvib.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+distortion:
+  name: gvib
+  mode: direct
+  kwargs:
+    z_dim: ${encoder.z_dim}
+    y_shape: ${data.target_aux_shape}
+    is_classification: ${data.is_classification_aux}
+    arch: null
+    arch_kwargs:
+      complexity: 2
+
+data:
+  kwargs:
+    dataset_kwargs:
+      additional_target: "max_inv"
9 changes: 9 additions & 0 deletions config/distortion/nce.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+distortion:
+  name: nce
+  mode: contrastive
+
+data:
+  kwargs:
+    dataset_kwargs:
+      additional_target: other_representative
16 changes: 16 additions & 0 deletions config/distortion/taskvib.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+distortion:
+  name: taskvib
+  mode: direct
+  kwargs:
+    z_dim: ${encoder.z_dim}
+    y_shape: ${data.target_aux_shape}
+    is_classification: ${data.is_classification_aux}
+    arch: null
+    arch_kwargs:
+      complexity: 2
+
+data:
+  kwargs:
+    dataset_kwargs:
+      additional_target: target
16 changes: 16 additions & 0 deletions config/distortion/vae.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+distortion:
+  name: vae
+  mode: direct
+  kwargs:
+    z_dim: ${encoder.z_dim}
+    y_shape: ${data.target_aux_shape}
+    is_classification: ${data.is_classification_aux}
+    arch: null
+    arch_kwargs:
+      complexity: 2
+
+data:
+  kwargs:
+    dataset_kwargs:
+      additional_target: "input"
16 changes: 16 additions & 0 deletions config/distortion/vib.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+distortion:
+  name: vib
+  mode: direct
+  kwargs:
+    z_dim: ${encoder.z_dim}
+    y_shape: ${data.target_aux_shape}
+    is_classification: ${data.is_classification_aux}
+    arch: null
+    arch_kwargs:
+      complexity: 2
+
+data:
+  kwargs:
+    dataset_kwargs:
+      additional_target: "idx"
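
Read together, the six distortion configs differ mainly in mode and in the auxiliary target they request from the dataset:

- gvae → additional_target: "representative" (direct)
- gvib → additional_target: "max_inv" (direct)
- nce → additional_target: other_representative (contrastive)
- taskvib → additional_target: target (direct)
- vae → additional_target: "input" (direct)
- vib → additional_target: "idx" (direct)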
2 changes: 1 addition & 1 deletion config/encoder/mlp.yaml
@@ -1,5 +1,5 @@
 # @package _group_
 name: mlp
-arch: flattenmlp # uses flattenmlp because giving image (3D tensor) as input
+arch: mlp # uses flattenmlp because giving image (3D tensor) as input
 arch_kwargs:
   complexity: 2
12 changes: 0 additions & 12 deletions config/loss/gvae.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions config/loss/gvib.yaml

This file was deleted.

10 changes: 0 additions & 10 deletions config/loss/probssl.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions config/loss/taskvib.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions config/loss/vae.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions config/loss/vib.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions config/mode/debug.yaml
@@ -8,6 +8,13 @@ trainer:
   overfit_batches: 0.05 # use 0.01 to make sure you can overfit 1% of training data => training works
   weights_summary: top # use `full` to print the entire model
   profiler: simple # use `simple` or `"advanced"` to find bottleneck
+  max_epochs: 5
+
+
+logger:
+  wandb:
+    tags: debug
+    anonymous: true
 
 callbacks:
   gpu_stats:
2 changes: 1 addition & 1 deletion config/coder/CMI_vamp.yaml → config/rate/CMI_vamp.yaml
@@ -1,5 +1,5 @@
 # @package _global_
-coder:
+rate:
   name: CMI_vamp
   kwargs:
     prior_fam: vamp
2 changes: 1 addition & 1 deletion config/coder/H_factorized.yaml → config/rate/H_factorized.yaml
@@ -1,5 +1,5 @@
 # @package _global_
-coder:
+rate:
   name: H_factorized
   kwargs:
     init_scale: 1 # input should be at initialization in [-init_scale, init_scale]. Larger might create overhead. You can use something like 10 for large images but smaller for smaller images
2 changes: 1 addition & 1 deletion config/coder/H_hyper.yaml → config/rate/H_hyper.yaml
@@ -1,5 +1,5 @@
 # @package _global_
-coder:
+rate:
   name: H_hyper
   kwargs:
     init_scale: 1 # input should be at initialization in [-init_scale, init_scale]. Larger might create overhead. You can use something like 10 for large images but smaller for smaller images
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions lossyless/__init__.py
@@ -1,6 +1,6 @@
 from .compressors import *
-from .coders import *
-from .losses import *
+from .rates import *
+from .distortions import *
 from .predictors import *
 from .architectures import *
 from .distributions import *
2 changes: 0 additions & 2 deletions lossyless/architectures.py
@@ -32,8 +32,6 @@ def get_Architecture(mode, complexity=2, **kwargs):
     """
     if mode == "mlp":
         # width 8,32,128,512,2048
-        return partial(MLP, hid_dim=8 * (4 ** (complexity)), **kwargs)
-    elif mode == "flattenmlp":
         return partial(FlattenMLP, hid_dim=8 * (4 ** (complexity)), **kwargs)
     elif mode == "resnet":
         base = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet150"]
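
After this change the "mlp" mode builds a FlattenMLP directly, which is why encoder/mlp.yaml can set arch: mlp while its comment still refers to flattening. FlattenMLP itself is not shown in this diff; a minimal sketch, assuming it simply flattens image inputs before a plain MLP:

    import torch.nn as nn

    class FlattenMLP(nn.Module):
        # assumed behavior: flatten (B, C, H, W) -> (B, C*H*W), then apply an MLP
        def __init__(self, in_shape, out_dim, hid_dim=128):
            super().__init__()
            in_dim = 1
            for d in in_shape:
                in_dim *= d
            self.mlp = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_dim, hid_dim),
                nn.ReLU(),
                nn.Linear(hid_dim, out_dim),
            )

        def forward(self, x):
            return self.mlp(x)

The hid_dim = 8 * (4 ** complexity) formula yields exactly the widths listed in the comment: 8, 32, 128, 512, and 2048 for complexity 0 through 4.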
4 changes: 2 additions & 2 deletions lossyless/callbacks.py
@@ -114,8 +114,8 @@ def __init__(self):
     def on_train_epoch_end(self, trainer, pl_module, outputs):
 
         #! waiting for torch lightning #1243
-        x_hat = pl_module._save["y_hat"].float()
-        x = pl_module._save["target"].float()
+        x_hat = pl_module._save["rec_img"].float()
+        x = pl_module._save["real_img"].float()
         # undo normalization for plotting
         x_hat, x = undo_normalization(x_hat, x, pl_module.hparams.data.dataset)
         caption = f"ep: {trainer.current_epoch}"
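
The callback reads its tensors from pl_module._save, now keyed rec_img/real_img. The producing side is not part of this diff; a hedged sketch of how a LightningModule might populate it (only the key names come from the diff — the forward signature and loss form are assumptions):

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_hat, rate, distortion = self(x)  # assumed forward signature
        if batch_idx == 0:
            # stash one batch per epoch for the plotting callback above
            self._save = {"real_img": x.detach(), "rec_img": x_hat.detach()}
        return distortion + self.hparams.loss.beta * rate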