Skip to content

Commit

Permalink
suppress ell warning with no_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
matthew-dowling committed Oct 10, 2022
1 parent 4efae57 commit f64a72b
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions 03_variational_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 86,
"id": "6444dc71-698d-4959-b5e9-0b28934e2af0",
"metadata": {},
"outputs": [],
Expand All @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 87,
"id": "e4b6a97a-a6b6-4a62-8edb-f377ec025cd6",
"metadata": {},
"outputs": [],
Expand All @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 88,
"id": "c5914fa1-ca6e-4045-8eeb-014fd2f6a968",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 89,
"id": "d8142194-ba57-4f56-aba9-d26a49af64ff",
"metadata": {},
"outputs": [],
Expand All @@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 90,
"id": "e60eded1-1f6e-4bda-a13c-99c1a359bc7e",
"metadata": {},
"outputs": [],
Expand All @@ -138,7 +138,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 91,
"id": "95341dae-1134-47b0-be79-dab5f6ebfcde",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -177,7 +177,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 92,
"id": "1f9008ab-5f73-4590-9793-f75f969d12e1",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -245,23 +245,25 @@
"execution_count": null,
"outputs": [],
"source": [
"_, z_sample, X_hat_VAE, log_P_hat_VAE = vae(Y, Y, 1.0)\n",
"with torch.no_grad():\n",
" _, z_sample, X_hat_VAE, log_P_hat_VAE = vae(Y, Y, 1.0)\n",
"P_hat_VAE = torch.exp(log_P_hat_VAE)\n",
"\n",
"z_sample = rearrange(z_sample, 'time trial latent -> trial time latent')\n",
"X_hat_VAE = rearrange(X_hat_VAE, 'time trial latent -> trial time latent')\n",
"P_hat_VAE = rearrange(X_hat_VAE, 'time trial latent -> trial time latent')\n",
"P_hat_VAE = rearrange(P_hat_VAE, 'time trial latent -> trial time latent')\n",
"plt.plot(z_sample[0, :, 0].detach().numpy())\n",
"\n",
"C_VAE = utils.estimate_readout_matrix(Y, X_hat_VAE, None, time_delta, 500) # refit for fairness\n",
"ell_VAE = utils.expected_ll_poisson(Y, np.asarray(X_hat_VAE), P_hat_VAE, C_VAE, time_delta)\n",
"ell_VAE = utils.expected_ll_poisson(Y, X_hat_VAE, P_hat_VAE, C_VAE, time_delta, dtype=torch.float64)\n",
"\n",
"print(f'VAE ell: {ell_VAE: .3f}')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
"name": "#%%\n",
"is_executing": true
}
}
},
Expand Down

0 comments on commit f64a72b

Please sign in to comment.