Skip to content

Commit bca2c13

Browse files
committed
Fixed bug in KL_loss calculation for VAE validation step during training
1 parent a51fdeb commit bca2c13

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

generation/maisi/maisi_train_vae_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@
692692
},
693693
{
694694
"cell_type": "code",
695-
"execution_count": 14,
695+
"execution_count": null,
696696
"id": "4c251a32-390f-46dd-a613-75b12a7884c1",
697697
"metadata": {
698698
"scrolled": true
@@ -850,7 +850,7 @@
850850
" with torch.no_grad():\n",
851851
" with autocast(\"cuda\", enabled=args.amp):\n",
852852
" images = batch[\"image\"]\n",
853-
" reconstruction, _, _ = dynamic_infer(val_inferer, autoencoder, images)\n",
853+
" reconstruction, z_mu, z_sigma = dynamic_infer(val_inferer, autoencoder, images)\n",
854854
" reconstruction = reconstruction.to(device)\n",
855855
" val_epoch_losses[\"recons_loss\"] += intensity_loss(reconstruction, images.to(device)).item()\n",
856856
" val_epoch_losses[\"kl_loss\"] += KL_loss(z_mu, z_sigma).item()\n",

0 commit comments

Comments
 (0)