Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_reference/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ The {mod}`frouros.utils` module contains auxiliary classes, functions or excepti
utils/checks
utils/data_structures
utils/kernels
utils/persistence
utils/stats
```
9 changes: 9 additions & 0 deletions docs/source/api_reference/utils/persistence.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Persistence

The {mod}`frouros.utils.persistence` module contains auxiliary functions to persistence objects.

```{eval-rst}
.. automodule:: frouros.utils.persistence
:members:
:no-inherited-members:
```
1 change: 1 addition & 0 deletions docs/source/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@

examples/concept_drift
examples/data_drift
examples/utils
```
7 changes: 7 additions & 0 deletions docs/source/examples/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Utils

```{toctree}
:maxdepth: 1

utils/save_load
```
359 changes: 359 additions & 0 deletions docs/source/examples/utils/save_load.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:36.559538Z",
"start_time": "2024-03-02T20:08:35.785936Z"
}
},
"outputs": [],
"source": [
"from functools import partial\n",
"import numpy as np\n",
"from scipy.spatial.distance import pdist\n",
"\n",
"from frouros.callbacks import PermutationTestDistanceBased\n",
"from frouros.detectors.data_drift import MMD\n",
"from frouros.utils import load, save\n",
"from frouros.utils.kernels import rbf_kernel"
]
},
{
"cell_type": "markdown",
"source": [
"# Save and Load detector\n",
"\n",
"In this example, we will demonstrate how to save and load a detector. We will use the MMD detector and the permutation test callback. We will first fit the detector and then compare two datasets. We will then save the detector to a file and load it back. We will then compare the same two datasets and assert that the distance and p-value are the same before and after saving and loading the detector."
],
"metadata": {
"collapsed": false
},
"id": "e3f1ddf0540a9259"
},
{
"cell_type": "markdown",
"source": [
"## Set random seed\n",
"\n",
"We will set the random seed to ensure reproducibility."
],
"metadata": {
"collapsed": false
},
"id": "4df73e55d7d353bb"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"seed = 31\n",
"np.random.seed(seed)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:36.567956Z",
"start_time": "2024-03-02T20:08:36.561066Z"
}
},
"id": "f913c4fc44d511f7",
"execution_count": 2
},
{
"cell_type": "markdown",
"source": [
"## Generate data\n",
"\n",
"We will generate two datasets. The first dataset will be generated from a multivariate normal distribution with mean [0, 0] and covariance matrix [[1, 0], [0, 1]]. The second dataset will be generated from a multivariate normal distribution with mean [1, 0] and covariance matrix [[1, 0], [0, 2]]."
],
"metadata": {
"collapsed": false
},
"id": "b08089f5ccf0f4d1"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"num_samples = 100\n",
"\n",
"x_mean = [0, 0]\n",
"x_cov = [\n",
" [1, 0],\n",
" [0, 1],\n",
"]\n",
"\n",
"y_mean = [1, 0]\n",
"y_cov = [\n",
" [1, 0],\n",
" [0, 2],\n",
"]\n",
"\n",
"X_ref = np.random.multivariate_normal(\n",
" mean=x_mean,\n",
" cov=x_cov,\n",
" size=num_samples,\n",
")\n",
"X_test = np.random.multivariate_normal(\n",
" mean=y_mean,\n",
" cov=y_cov,\n",
" size=num_samples,\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:36.583840Z",
"start_time": "2024-03-02T20:08:36.570122Z"
}
},
"id": "188b82ee45c1a092",
"execution_count": 3
},
{
"cell_type": "markdown",
"source": [
"## Fit detector\n",
"\n",
"We will fit the detector using the reference dataset."
],
"metadata": {
"collapsed": false
},
"id": "dd7dd35a96e1651a"
},
{
"cell_type": "code",
"outputs": [
{
"data": {
"text/plain": "1.5941478725484344"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sigma = np.median(\n",
" pdist(\n",
" X=X_ref,\n",
" metric=\"euclidean\",\n",
" ),\n",
" )\n",
"sigma"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:36.599907Z",
"start_time": "2024-03-02T20:08:36.584853Z"
}
},
"id": "23fac866bcd656ee",
"execution_count": 4
},
{
"cell_type": "code",
"outputs": [],
"source": [
"detector = MMD(\n",
" kernel=partial(\n",
" rbf_kernel,\n",
" sigma=sigma,\n",
" ),\n",
" callbacks=PermutationTestDistanceBased(\n",
" num_permutations=100,\n",
" num_jobs=-1,\n",
" method=\"exact\",\n",
" random_state=seed,\n",
" name=\"permutation_test\",\n",
" ),\n",
")\n",
"\n",
"_ = detector.fit(\n",
" X=X_ref,\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:36.615923Z",
"start_time": "2024-03-02T20:08:36.603076Z"
}
},
"id": "3bf7b070454ba708",
"execution_count": 5
},
{
"cell_type": "markdown",
"source": [
"## Compare datasets before saving\n",
"\n",
"We will compare the reference and test datasets."
],
"metadata": {
"collapsed": false
},
"id": "ca0bee617c055e14"
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Distance: 0.14644993, p-value: 0.00990049\n"
]
}
],
"source": [
"distance, callback_logs = detector.compare(\n",
" X=X_test,\n",
")\n",
"before_save_distance = distance.distance\n",
"before_save_p_value = callback_logs['permutation_test']['p_value']\n",
"print(f\"Distance: {before_save_distance:.8f}, p-value: {before_save_p_value:.8f}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:39.021802Z",
"start_time": "2024-03-02T20:08:36.616944Z"
}
},
"id": "c1f670b30658a751",
"execution_count": 6
},
{
"cell_type": "markdown",
"source": [
"## Save and Load detector\n",
"\n",
"We will save the detector to a file and load it back."
],
"metadata": {
"collapsed": false
},
"id": "4dad43da2f94c1ec"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"save(\n",
" obj=detector,\n",
" filename=\"detector.pkl\",\n",
")\n",
"\n",
"detector = load(\n",
" filename=\"detector.pkl\",\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:39.037744Z",
"start_time": "2024-03-02T20:08:39.024229Z"
}
},
"id": "d0aa212a9e91de5c",
"execution_count": 7
},
{
"cell_type": "markdown",
"source": [
"## Compare datasets after loading\n",
"\n",
"We will compare the reference and test datasets again."
],
"metadata": {
"collapsed": false
},
"id": "97d354f3aaf7f555"
},
{
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Distance: 0.14644993, p-value: 0.00990049\n"
]
}
],
"source": [
"distance, callback_logs = detector.compare(\n",
" X=X_test,\n",
")\n",
"after_save_distance = distance.distance\n",
"after_save_p_value = callback_logs['permutation_test']['p_value']\n",
"print(f\"Distance: {after_save_distance:.8f}, p-value: {after_save_p_value:.8f}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:41.628646Z",
"start_time": "2024-03-02T20:08:39.038798Z"
}
},
"id": "a681537ba868af6b",
"execution_count": 8
},
{
"cell_type": "markdown",
"source": [
"Assert that the distance and p-value are the same before and after saving and loading the detector."
],
"metadata": {
"collapsed": false
},
"id": "3a81841ec13cc881"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"assert before_save_distance == after_save_distance\n",
"assert before_save_p_value == after_save_p_value"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-02T20:08:41.644471Z",
"start_time": "2024-03-02T20:08:41.629678Z"
}
},
"id": "1a7e98cb985f2e5b",
"execution_count": 9
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading