Skip to content

Commit cd75fd0

Browse files
committed
Fixed generate samples script
1 parent 3beca41 commit cd75fd0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

evaluation/visualization/generate_samples.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33

44
import click
55
import torch
6-
from torchvision.transforms import transforms
6+
from torchvision.utils import save_image
77

88
from utils.setup_utils import get_device, load_yaml_config
99
from utils.training_utils import load_vae_architecture
1010

1111

1212
def _save_sample(sample: torch.Tensor, vae_name: str):
13-
sample_image = transforms.ToPILImage()(sample[0])
14-
sample_image.save(f"sample_{vae_name}.png")
13+
save_image(sample[0], f"sample_{vae_name}.png")
1514

1615

1716
@click.command()
@@ -39,7 +38,8 @@ def main(vae_dirs: List[str], gpu: int, best_vae: bool, seed: int):
3938
vae_version = specific_vae_dir.split("version_")[-1]
4039
vae_version = "v_" + vae_version
4140

42-
sample = vae.decode(random_latent_vector)
41+
with torch.no_grad():
42+
sample = vae.decode(random_latent_vector)
4343

4444
_save_sample(sample, vae_version)
4545

0 commit comments

Comments
 (0)