Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion avae/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ class AffinityConfig(BaseModel):
description="Optimisation method.It can be adam/sgd/asgd",
pattern='^(adam|sgd|asgd)$',
)
pose_dims: PositiveInt = Field(1, description="Pose dimensions")
pose_dims: int = Field(1, description="Pose dimensions")

rescale: float = Field(None, description="Rescale data")
restart: bool = Field(False, description="Restart training")
shift_min: bool = Field(
Expand Down
2 changes: 1 addition & 1 deletion avae/decoders/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(

# add a final convolutional decoder to generate an image if the number
# of output channels has been provided
if output_channels != 0:
if output_channels is not 0:
conv = (
torch.nn.Conv2d
if self._ndim == SpatialDims.TWO
Expand Down
7 changes: 5 additions & 2 deletions avae/encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,11 @@ def forward(self, x):
encoded = self.encoder(x)
mu = self.mu(encoded)
log_var = self.log_var(encoded)
pose = self.pose_fc(encoded)
return mu, log_var, pose
if self.pose:
pose = self.pose_fc(encoded)
return mu, log_var, pose
else:
return mu, log_var


class EncoderB(AbstractEncoder):
Expand Down
25 changes: 8 additions & 17 deletions avae/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import settings, vis
from .data import load_data
from .utils import accuracy
from .utils import accuracy, latest_file
from .utils_learning import add_meta, pass_batch, set_device


Expand Down Expand Up @@ -85,16 +85,13 @@ def evaluate(
"There are no existing model states saved or provided via the state flag in config unable to evaluate."
)
else:
state = sorted(
[s for s in os.listdir("states") if ".pt" in s],
key=lambda x: int(x.split("_")[2][1:]),
)[-1]
state = latest_file("states", ".pt")
state = os.path.join("states", state)

s = os.path.basename(state)
fname = s.split(".")[0].split("_")
dshape = list(tests)[0][0].shape[2:]
pose_dims = fname[3]
pose_dims = int(fname[-1])

logging.info("Loading model from: {}".format(state))
checkpoint = torch.load(state)
Expand All @@ -105,14 +102,7 @@ def evaluate(
# ########################## EVALUATE ################################

if meta is None:
metas = sorted(
[
f
for f in os.listdir("states")
if ".pkl" in f and "eval" not in f
],
key=lambda x: int(x.split("_")[2][1:]),
)[-1]
metas = latest_file("states", ".pkl")
meta = os.path.join("states", metas)

logging.info("Loading model from: {}".format(meta))
Expand All @@ -122,8 +112,9 @@ def evaluate(
x_test = []
y_test = []
c_test = []
p_test = None

if pose_dims != 0:
if pose_dims is not 0:
p_test = []

logging.debug("Batch: [0/%d]" % (len(tests)))
Expand Down Expand Up @@ -179,7 +170,7 @@ def evaluate(
)

# visualise pose disentanglement
if pose_dims != 0 and settings.VIS_POS:
if pose_dims is not 0 and settings.VIS_POS:
vis.pose_disentanglement_plot(
dshape,
x_test,
Expand All @@ -189,7 +180,7 @@ def evaluate(
mode="_eval",
)

if pose_dims != 0 and settings.VIS_POSE_CLASS:
if pose_dims is not 0 and settings.VIS_POSE_CLASS:
vis.pose_class_disentanglement_plot(
dshape,
x_test,
Expand Down
20 changes: 20 additions & 0 deletions avae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,23 @@ def latent_space_similarity_mat(
] # symmetrical matrix

return cosine_sim_mat


def latest_file(path: str, extension: str) -> str:

most_recent_file = ""
most_recent_time = 0

# iterate over the files in the directory using os.scandir
for entry in os.scandir(path):
if entry.name.lower().endswith(extension):
# get the modification time of the file using entry.stat().st_mtime_ns
mod_time = entry.stat().st_mtime_ns
if (
mod_time > most_recent_time
and "eval" not in entry.name.lower()
):
# update the most recent file and its modification time
most_recent_file = entry.name
most_recent_time = mod_time
return most_recent_file
8 changes: 5 additions & 3 deletions avae/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def latent_embed_plot_tsne(
epoch: int = 0,
writer: typing.Any = None,
perplexity: int = 40,
marker_size: int = 24,
l_w: int = 2,
display: bool = False,
) -> None:
"""Plot static TSNE embedding.
Expand Down Expand Up @@ -253,7 +255,7 @@ def latent_embed_plot_tsne(
plt.scatter(
lats[idx, 0],
lats[idx, 1],
s=24,
s=marker_size,
label=mol[:4],
facecolor=color,
edgecolor=color,
Expand All @@ -277,6 +279,7 @@ def latent_embed_plot_tsne(
stacked=True,
fill=False,
label=mol[:4],
linewidth=l_w,
)
plt.legend(
prop={"size": 10},
Expand Down Expand Up @@ -1284,14 +1287,13 @@ def latent_4enc_interpolate_plot(
enc = []

draw_four = random.sample(range(len(classes)), k=4)
selected_classes = [classes[index] for index in draw_four]
for idx in draw_four:
lat = np.take(
xs,
random.sample(list(np.where(ys == classes[idx])[0]), k=1),
axis=0,
)
enc.append(lat)
enc.append(lat.numpy())

enc = np.asarray(enc)
alpha_values = torch.linspace(0, 1, num_steps)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_train_eval_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_model_nopose(self):
self.assertEqual(n_plots_train, 28)
self.assertEqual(n_latent_train, 2)
self.assertEqual(n_states_train, 2)
self.assertEqual(n_plots_eval, 45)
self.assertEqual(n_plots_eval, 44)
self.assertEqual(n_latent_eval, 4)
self.assertEqual(n_states_eval, 3)

Expand Down
144 changes: 45 additions & 99 deletions tools/plotting_post_process.ipynb

Large diffs are not rendered by default.