Skip to content

Commit

Permalink
Adapt Examples to new api
Browse files Browse the repository at this point in the history
- Computing CVs Example
- Contrastive Loss Example
- Using data recorder Example
  • Loading branch information
knikolaou committed May 16, 2024
1 parent 45aea40 commit afe6d1a
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 115 deletions.
74 changes: 48 additions & 26 deletions examples/Computing-Collective-Variables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
"from neural_tangents import stax\n",
"import optax\n",
"\n",
"from papyrus.measurements import (\n",
" Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy, NTKEigenvalues\n",
")\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -184,7 +188,7 @@
"- NTK\n",
"- NTK Eigenvalues\n",
"- Entropy of the NTK\n",
"- Magnitude Variance of the NTK\n",
"- Self-Entropy of the NTK\n",
"- Trace of the NTK\n",
"- Frobenius norm of the Loss Derivative\n",
"\n",
Expand All @@ -200,28 +204,39 @@
"source": [
"train_recorder = nl.training_recording.JaxRecorder(\n",
" name=\"train_recorder\",\n",
" loss=True,\n",
" ntk=True,\n",
" covariance_entropy=True,\n",
" eigenvalues=True,\n",
" magnitude_variance=True, \n",
" trace=True,\n",
" loss_derivative=True,\n",
" update_rate=1\n",
" measurements=[\n",
" Loss(name=\"loss\", apply_fn=nl.loss_functions.LPNormLoss(order=2)),\n",
" Accuracy(name=\"accuracy\", apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
" NTKTrace(name=\"ntk_trace\"),\n",
" NTKEntropy(name=\"ntk_entropy\"),\n",
" NTK(name=\"ntk\"),\n",
" NTKSelfEntropy(name=\"ntk_self_entropy\"),\n",
" NTKEigenvalues(name=\"ntk_eigenvalues\"),\n",
" ],\n",
" storage_path=\".\",\n",
" update_rate=1, \n",
" chunk_size=1e5\n",
")\n",
"train_recorder.instantiate_recorder(\n",
" data_set=data_generator.train_ds, \n",
" ntk_computation=ntk_computation\n",
" ntk_computation=ntk_computation, \n",
" model=fuel_model\n",
")\n",
"\n",
"\n",
"test_recorder = nl.training_recording.JaxRecorder(\n",
" name=\"test_recorder\",\n",
" loss=True,\n",
" update_rate=1\n",
" measurements=[\n",
" Loss(name=\"loss\", apply_fn=nl.loss_functions.LPNormLoss(order=2)),\n",
" Accuracy(name=\"accuracy\", apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
" ],\n",
" storage_path=\".\",\n",
" update_rate=1, \n",
" chunk_size=1e5\n",
")\n",
"test_recorder.instantiate_recorder(\n",
" data_set=data_generator.test_ds\n",
" data_set=data_generator.test_ds, \n",
" model=fuel_model\n",
")"
]
},
Expand Down Expand Up @@ -294,8 +309,8 @@
"metadata": {},
"outputs": [],
"source": [
"train_report = train_recorder.gather_recording()\n",
"test_report = test_recorder.gather_recording()"
"train_report = train_recorder.gather()\n",
"test_report = test_recorder.gather()"
]
},
{
Expand All @@ -305,8 +320,8 @@
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_report.loss, 'o', mfc='None', label=\"Train\")\n",
"plt.plot(test_report.loss, 'o', mfc='None', label=\"Train\")\n",
"plt.plot(train_report[\"loss\"], 'o', mfc='None', label=\"Train\")\n",
"plt.plot(test_report[\"loss\"], 'o', mfc='None', label=\"Test\")\n",
"\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
Expand All @@ -322,7 +337,7 @@
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_report.covariance_entropy, 'o', mfc='None', label=\"Entropy\")\n",
"plt.plot(train_report['ntk_entropy'], 'o', mfc='None', label=\"Entropy\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Entropy\")\n",
"plt.legend()\n",
Expand All @@ -336,7 +351,7 @@
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_report.magnitude_variance, 'o', mfc='None', label=\"Magnitude Variance\")\n",
"plt.plot(train_report['ntk_self_entropy'], 'o', mfc='None', label=\"Magnitude Variance\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Magnitude Variance\")\n",
"plt.legend()\n",
Expand All @@ -346,25 +361,32 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0d3813e3",
"id": "6d43257c-defc-4f1e-816a-ebe1ae79e7ca",
"metadata": {},
"outputs": [],
"source": [
"train_report.eigenvalues.shape"
"plt.plot(train_report['ntk_trace'], 'o', mfc='None', label=\"Trace\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Trace\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d43257c-defc-4f1e-816a-ebe1ae79e7ca",
"id": "2ff7f726",
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_report.trace, 'o', mfc='None', label=\"Trace\")\n",
"plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,0], 'o', mfc='None', label=\"Largest EV\")\n",
"plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,1], 'o', mfc='None', label=\"2nd Largest EV\")\n",
"plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,2], 'o', mfc='None', label=\"3rd Largest EV\")\n",
"plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,3], 'o', mfc='None', label=\"4th Largest EV\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Trace\")\n",
"plt.legend()\n",
"plt.show()"
"plt.ylabel(\"Eigenvalues\")\n",
"plt.yscale(\"log\")\n",
"plt.legend()"
]
},
{
Expand Down
Loading

0 comments on commit afe6d1a

Please sign in to comment.