Skip to content

Commit

Permalink
The Anderson-Darling test
Browse files Browse the repository at this point in the history
  • Loading branch information
Михаил Гущин authored and Михаил Гущин committed Apr 1, 2023
1 parent 04e1fba commit 6131e5c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
6 changes: 4 additions & 2 deletions probaforms/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .fd import frechet_distance
from .mmd import maximum_mean_discrepancy
from .ks1d import kolmogorov_smirnov_1d, cramer_von_mises_1d, roc_auc_score_1d
from .ks1d import kolmogorov_smirnov_1d, cramer_von_mises_1d
from .ks1d import roc_auc_score_1d, anderson_darling_1d
from .div1d import kullback_leibler_1d, kullback_leibler_1d_kde
from .div1d import jensen_shannon_1d, jensen_shannon_1d_kde

Expand All @@ -14,5 +15,6 @@
'kullback_leibler_1d_kde',
'jensen_shannon_1d_kde',
'cramer_von_mises_1d',
'roc_auc_score_1d'
'roc_auc_score_1d',
'anderson_darling_1d'
]
44 changes: 39 additions & 5 deletions probaforms/metrics/ks1d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import numpy as np
from sklearn.utils import resample
from scipy.stats import ks_2samp, cramervonmises_2samp
from scipy.stats import ks_2samp, cramervonmises_2samp, anderson_ksamp
from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings(action='ignore', category=UserWarning) # for _anderson1d



def _ks1d(data1, data2):
ks, _ = ks_2samp(data1, data2)
Expand All @@ -19,6 +23,10 @@ def _roc1d(x, y):
auc = np.abs(auc - 0.5) + 0.5
return auc

def _anderson1d(x, y):
res = anderson_ksamp([x, y])
return res.statistic


def _bootstrap_metric(metric_func, X_real, X_fake, n_iters=100, *args):
'''
Expand Down Expand Up @@ -64,7 +72,7 @@ def _bootstrap_metric(metric_func, X_real, X_fake, n_iters=100, *args):

def kolmogorov_smirnov_1d(X_real, X_fake, n_iters=100):
'''
Calculates the Kolmogorov Smirnov statistics for real and fake samples.
Calculates the Kolmogorov Smirnov statistic for real and fake samples.
The function calculates metric values for each input feature,
and then averages them.
Expand All @@ -80,7 +88,7 @@ def kolmogorov_smirnov_1d(X_real, X_fake, n_iters=100):
Return:
-------
distance: float
The estimated KS statistics.
The estimated KS statistic.
Std: float
The standard deviation of the distance.
'''
Expand All @@ -106,7 +114,7 @@ def cramer_von_mises_1d(X_real, X_fake, n_iters=100):
Return:
-------
distance: float
The estimated KS statistics.
The estimated CvM statistic.
Std: float
The standard deviation of the distance.
'''
Expand All @@ -132,9 +140,35 @@ def roc_auc_score_1d(X_real, X_fake, n_iters=100):
Return:
-------
distance: float
The estimated KS statistics.
The estimated statistic.
Std: float
The standard deviation of the distance.
'''

return _bootstrap_metric(_roc1d, X_real, X_fake, n_iters)


def anderson_darling_1d(X_real, X_fake, n_iters=100):
    '''
    Estimates the Anderson-Darling statistic between real and fake samples.

    The statistic is computed separately for every input feature and the
    per-feature values are then averaged. The bootstrap itself is delegated
    to `_bootstrap_metric`, with `_anderson1d` as the per-feature metric.

    Parameters:
    -----------
    X_real: numpy.ndarray of shape [n_samples, n_features]
        Real sample.
    X_fake: numpy.ndarray of shape [n_samples, n_features]
        Generated sample.
    n_iters: int
        The number of bootstrap iterations. Default = 100.

    Return:
    -------
    distance: float
        The estimated statistic.
    Std: float
        The standard deviation of the distance.
    '''

    bootstrap_result = _bootstrap_metric(
        _anderson1d, X_real, X_fake, n_iters=n_iters
    )
    return bootstrap_result

0 comments on commit 6131e5c

Please sign in to comment.