From 19cd864b3d2c79f8a3e2d72efa3c6b74a2f3dd03 Mon Sep 17 00:00:00 2001 From: stsouko Date: Thu, 4 Jun 2020 20:28:40 +0300 Subject: [PATCH] docs configured --- CIMtools/datasets/__init__.py | 5 +- CIMtools/metrics/__init__.py | 5 +- CIMtools/metrics/pairwise.py | 3 + CIMtools/model_selection/__init__.py | 5 +- CIMtools/preprocessing/__init__.py | 8 + README.md => README.rst | 50 +- doc/conf.py | 6 +- doc/datasets.rst | 8 + doc/index.rst | 24 +- doc/metrics.rst | 8 + doc/model_selection.rst | 8 + doc/preprocessing.rst | 8 + doc/standardize.rst | 8 + doc/tutorial/applicability_domain.ipynb | 1780 +++++++++++++++++++++++ environment.yml | 14 + readthedocs.yml | 11 +- requirements.txt | 5 - setup.py | 7 +- 18 files changed, 1894 insertions(+), 69 deletions(-) rename README.md => README.rst (69%) create mode 100644 doc/datasets.rst create mode 100644 doc/metrics.rst create mode 100644 doc/model_selection.rst create mode 100644 doc/preprocessing.rst create mode 100644 doc/standardize.rst create mode 100644 doc/tutorial/applicability_domain.ipynb create mode 100644 environment.yml delete mode 100644 requirements.txt diff --git a/CIMtools/datasets/__init__.py b/CIMtools/datasets/__init__.py index dec45b5..236926f 100644 --- a/CIMtools/datasets/__init__.py +++ b/CIMtools/datasets/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2018 Ramil Nugmanov +# Copyright 2018, 2020 Ramil Nugmanov # This file is part of CIMtools. # # CIMtools is free software; you can redistribute it and/or modify @@ -17,3 +17,6 @@ # along with this program; if not, see . # from .molconvert_chemaxon import * + + +__all__ = ['molconvert_chemaxon'] diff --git a/CIMtools/metrics/__init__.py b/CIMtools/metrics/__init__.py index 9c904d1..65c56d6 100644 --- a/CIMtools/metrics/__init__.py +++ b/CIMtools/metrics/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2018 Ramil Nugmanov +# Copyright 2018, 2020 Ramil Nugmanov # This file is part of CIMtools. # # CIMtools is free software; you can redistribute it and/or modify @@ -18,3 +18,6 @@ # from .pairwise import * from .applicability_domain_metrics import * + + +__all__ = ['balanced_accuracy_score_with_ad', 'rmse_score_with_ad', 'tanimoto_kernel'] diff --git a/CIMtools/metrics/pairwise.py b/CIMtools/metrics/pairwise.py index 20a5412..21e5204 100644 --- a/CIMtools/metrics/pairwise.py +++ b/CIMtools/metrics/pairwise.py @@ -20,6 +20,9 @@ def tanimoto_kernel(x, y): + """ + + """ x_dot = np.dot(x, y.T) x2 = (x**2).sum(axis=1) diff --git a/CIMtools/model_selection/__init__.py b/CIMtools/model_selection/__init__.py index e611425..63cb608 100644 --- a/CIMtools/model_selection/__init__.py +++ b/CIMtools/model_selection/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2018, 2019 Ramil Nugmanov +# Copyright 2018-2020 Ramil Nugmanov # This file is part of CIMtools. # # CIMtools is free software; you can redistribute it and/or modify @@ -19,3 +19,6 @@ from .group_out import * from .reaction_type_control_selection import * from .transformation_out import * + + +__all__ = ['LeaveOneGroupOut', 'TransformationOut', 'rtc_env_selection'] diff --git a/CIMtools/preprocessing/__init__.py b/CIMtools/preprocessing/__init__.py index 8b49431..f60a5ab 100644 --- a/CIMtools/preprocessing/__init__.py +++ b/CIMtools/preprocessing/__init__.py @@ -24,3 +24,11 @@ from .graph_to_matrix import * from .solvent import * from .standardize import * + + +__all__ = ['Conditions', 'DictToConditions', 'ConditionsToDataFrame', 'SolventVectorizer', 'EquationTransformer', + 'CGR', 'MoleculesToMatrix', 'CGRToMatrix'] + +if 'Fragmentor' in locals(): + __all__.append('Fragmentor') + __all__.append('FragmentorFingerprint') diff --git a/README.md b/README.rst similarity index 69% rename from README.md rename to README.rst index 3ba9e28..d9323c6 100644 --- a/README.md +++ b/README.rst @@ -10,78 +10,58 @@ INSTALL Linux Debian based ------------------ -* Install python3.7, virtualenv and git +* Install python3.7, virtualenv and git:: - ``` sudo apt install python3.7 python3.7-dev git python3-virtualenv - ``` -* Create new environment and activate it. +* Create new environment and activate it:: - ``` virtualenv -p python3.7 venv source venv/bin/activate - ``` Mac --- -* Install python3.7 and git using [brew]() +* Install python3.7 and git using :: - ``` brew install git brew install python3 - ``` - -* Install virtualenv. - ``` +* Install virtualenv:: + pip install virtualenv - ``` -* Create new environment and activate it. +* Create new environment and activate it:: - ``` virtualenv -p python3.7 venv source venv/bin/activate - ``` - + Windows ------- -* Install python3.7 and git using [Chocolatey]() +* Install python3.7 and git using :: - ``` choco install git choco install python3 - ``` -* Install virtualenv. +* Install virtualenv:: - ``` pip install virtualenv - ``` -* Create new environment and activate it. +* Create new environment and activate it:: - ``` virtualenv venv venv\Scripts\activate - ``` General part ------------ -* **stable version will be available through PyPI** +* **stable version will be available through PyPI**:: - ``` pip install CIMtools - ``` -* Install CGRtools library DEV version for features that are not well tested. Git lfs installation required https://git-lfs.github.com/. +* Install CGRtools library DEV version for features that are not well tested. Git lfs installation required :: - ``` pip install -U git+https://github.com/stsouko/CIMtools.git@master#egg=CIMtools - ``` **If you still have questions, please open issue within github.** @@ -89,7 +69,7 @@ SETUP ===== For ChemAxon standardizer used pyjnius. First of all install JDK (not JRE) OpenJDK or Oracle. -Some times it can't to find java installation properly. Just set environment variables. +Some times it can't to find java installation properly. Just set environment variables:: JAVA_HOME = '/path/to/dir/which/contain/bin/dir'. for example /usr/lib/jvm/java-11-openjdk-amd64 JVM_PATH = '/path/to/lib/server/libjvm.so'. For example '/usr/lib/jvm/java-11-openjdk-amd64/lib/server/libjvm.so' @@ -97,11 +77,11 @@ Some times it can't to find java installation properly. Just set environment var PACKAGING ========= -For wheel generation just type next command in source root +For wheel generation just type next command in source root:: python setup.py bdist_wheel -On Linux additionally do repairing of package +On Linux additionally do repairing of package:: pip install auditwheel auditwheel repair dist/CIMtools---linux_x86_64.whl diff --git a/doc/conf.py b/doc/conf.py index aa86ffc..6620846 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -11,7 +11,7 @@ project = 'CIMtools' needs_sphinx = '1.8' -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'numpydoc', 'm2r'] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'numpydoc', 'nbsphinx'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] @@ -24,7 +24,7 @@ todo_include_todos = False autoclass_content = 'both' -html_theme_options = {'github_user': 'stsouko', 'github_repo': 'CIMtools', 'show_related': True} +html_theme_options = {'github_user': 'cimm_kzn', 'github_repo': 'CIMtools', 'show_related': True} html_show_copyright = True html_show_sourcelink = False html_sidebars = { @@ -35,3 +35,5 @@ 'searchbox.html', ] } + +nbsphinx_execute = 'never' diff --git a/doc/datasets.rst b/doc/datasets.rst new file mode 100644 index 0000000..604d53c --- /dev/null +++ b/doc/datasets.rst @@ -0,0 +1,8 @@ +CIMtools\.datasets package +========================== + +.. automodule:: CIMtools.datasets + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/doc/index.rst b/doc/index.rst index c44cd7b..479706c 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,25 +1,23 @@ -.. mdinclude:: ../README.md - -CIMtools package -================ - -.. automodule:: CIMtools - :members: - :undoc-members: - :show-inheritance: +.. include:: ../README.rst Subpackages ------------ +=========== .. toctree:: - :maxdepth: 4 + :maxdepth: 2 applicability_domain + model_selection + metrics + datasets + preprocessing + standardize Tutorial --------- +======== .. toctree:: - :maxdepth: 4 + :maxdepth: 1 tutorial/metric_constants + tutorial/applicability_domain.ipynb diff --git a/doc/metrics.rst b/doc/metrics.rst new file mode 100644 index 0000000..c40fc7f --- /dev/null +++ b/doc/metrics.rst @@ -0,0 +1,8 @@ +CIMtools\.metrics package +========================= + +.. automodule:: CIMtools.metrics + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/doc/model_selection.rst b/doc/model_selection.rst new file mode 100644 index 0000000..3e5a23a --- /dev/null +++ b/doc/model_selection.rst @@ -0,0 +1,8 @@ +CIMtools\.model_selection package +================================= + +.. automodule:: CIMtools.model_selection + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst new file mode 100644 index 0000000..b9aa7b2 --- /dev/null +++ b/doc/preprocessing.rst @@ -0,0 +1,8 @@ +CIMtools\.preprocessing package +=============================== + +.. automodule:: CIMtools.preprocessing + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/doc/standardize.rst b/doc/standardize.rst new file mode 100644 index 0000000..c812695 --- /dev/null +++ b/doc/standardize.rst @@ -0,0 +1,8 @@ +CIMtools\.preprocessing.standardize package +=========================================== + +.. automodule:: CIMtools.preprocessing.standardize + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/doc/tutorial/applicability_domain.ipynb b/doc/tutorial/applicability_domain.ipynb new file mode 100644 index 0000000..3f2e87c --- /dev/null +++ b/doc/tutorial/applicability_domain.ipynb @@ -0,0 +1,1780 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BENCHMARKING OF APPLICABILITY DOMAINS FOR MODELS PREDICTING PROPERTIES OF CHEMICAL REACTIONS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quantitative Reaction-Property model predicting the rate concats of reactions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nowadays, Quantitative Structure-Activity/Property Relationship (QSAR/QSPR) models are widely used for predicting properties of chemical compounds [[1], [2]]. The growing attention is attracted to chemical reactions as objects of QSAR/QSPR-like modelling [[3],[4],[5]]. Below, you will find a model (**QRPR, Quantitative Reaction-Property model**) predicting the quantitative characteristic (rate constant) of bimolecular nucleophilic substitution reactions.\n", + "\n", + "[1]: https://pubs.acs.org/doi/10.1021/jm4004285\n", + "[2]: https://www.sciencedirect.com/science/article/pii/B9780128015056000077?via%3Dihub\n", + "[3]: http://mr.crossref.org/iPage?doi=10.1070%2FRCR4746\n", + "[4]: https://pubs.acs.org/doi/abs/10.1021/acs.accounts.8b00087\n", + "[5]: https://onlinelibrary.wiley.com/doi/abs/10.1002/minf.201800104" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import StandardScaler, FunctionTransformer\n", + "from os import environ\n", + "from CIMtools.preprocessing import Fragmentor, CGR, EquationTransformer, SolventVectorizer\n", + "from CIMtools.preprocessing.conditions_container import DictToConditions, ConditionsToDataFrame\n", + "\n", + "def extract_meta(x):\n", + " return [y[0].meta for y in x]\n", + "\n", + "def x_generation(data_train, data_test):\n", + " environ[\"PATH\"]+=\":/home/assima/env/bin\"\n", + " features = ColumnTransformer([('temp', EquationTransformer('1/x'), ['temperature']),\n", + " ('solv', SolventVectorizer(), ['solvent.1']),\n", + " ('amount', 'passthrough', ['solvent_amount.1']),])\n", + " conditions = Pipeline([('meta', FunctionTransformer(extract_meta)),\n", + " ('cond', DictToConditions(solvents=('additive.1',), \n", + " temperature='temperature', \n", + " amounts=('amount.1',))),\n", + " ('desc', ConditionsToDataFrame()),\n", + " ('final', features)])\n", + " graph = Pipeline([('CGR', CGR()), \n", + " ('frg', Fragmentor(fragment_type=3, max_length=4, useformalcharge=True, version='2017')), \n", + " ('scaler', StandardScaler())]) # All descriptors were normalized to zero mean and unit variance.\n", + "\n", + " pp = ColumnTransformer([('cond', conditions, [0]),\n", + " ('graph', graph, [0])])\n", + " X_train = pp.fit_transform([[x] for x in data_train])\n", + " X_test = pp.transform([[x] for x in data_test])\n", + " return X_train, X_test" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "4830it [00:10, 442.36it/s]\n" + ] + } + ], + "source": [ + "from tqdm import tqdm\n", + "from CGRtools import RDFRead\n", + "\n", + "data = RDFRead('/home/assima/Assima_purple/home/Datasets/SN2_11_11_2019.rdf')\n", + "reactions = []\n", + "for n, r in tqdm(enumerate(data._data)):\n", + " r.kekule() # leads to kekula formula for benzene rings\n", + " r.implicify_hydrogens() # removes hydrogens\n", + " r.thiele() # aromatizes benzene rings\n", + " reactions.append(r)\n", + "\n", + "del data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the study, we use **Random Forest Regression (RFR)** and **Gaussian Process Regression (GPR)** for building QRPR models\n", + "\n", + "The number of trees in **RandomForestRegressor** is set to 500, while the only tuneable hyperparameter is the number of features selected upon tree branching (max_features). Other hyperparameters in **RandomForestRegressor** are set to default values. \n", + "\n", + "For **GPR** models, hyperparameters of noise level, alpha, and RBF kernel’s gamma values are adjusted. \n", + "\n", + "To obtain a reliable assessment of predictive performance and avoid overfitting, the nested cross-validation procedure [[6](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-7-91)] is used. For each training/test split in the outer loop, the hyperparameters of **RandomForestRegressor** and **Gaussian Process Regression** models are tuned using grid search by minimizing the averaged RMSE of prediction (without AD application) estimated in the inner cross-validation loop on the outer training set, and the optimal models with the tuned hyperparameters are used to predict reaction properties on the outer test set. \n", + "\n", + "For simplicity, we will break on the first fold" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 1 candidates, totalling 5 fits\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 40 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 42.6s finished\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "from sklearn.gaussian_process import GaussianProcessRegressor\n", + "from sklearn.gaussian_process.kernels import RBF\n", + "from sklearn.model_selection import KFold, GridSearchCV\n", + "from sklearn.utils import safe_indexing\n", + "\n", + "kf = KFold(n_splits=5, random_state=1, shuffle=True) \n", + "\n", + "Y_true_rfr, Y_pred_rfr, Y_pred_gpr, Y_true_gpr = [], [], [], [] # collect all the values of logK\n", + "\n", + "for train_index_ext, test_index_ext in kf.split(reactions): # external set\n", + " reactions_train = safe_indexing(reactions, train_index_ext)\n", + " reactions_test = safe_indexing(reactions, test_index_ext)\n", + " X_train, X_test = x_generation(reactions_train, reactions_test) # fragment descriptors\n", + " Y_train = [float(x.meta['logK']) for x in reactions_train] # predictable property is the rate constants of the reactions\n", + " Y_test = [float(x.meta['logK']) for x in reactions_test] \n", + "\n", + " est = GridSearchCV(RandomForestRegressor(random_state=1, n_estimators=500),\n", + " {'max_features': [ None]},\n", + " cv=kf, verbose=1, scoring='neg_mean_squared_error', n_jobs=-1).fit(X_train, Y_train) # internal set \n", + " # [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 'auto', 'log2', None]\n", + " Y_pred = est.predict(X_test)\n", + " Y_pred_rfr.extend(Y_pred)\n", + " Y_true_rfr.extend(Y_test)\n", + " \n", + " # All descriptors were normalized to zero mean and unit variance. \n", + " # In the case of GPR models, both descriptors and property values were normalized, \n", + " # because this provided better predictive performance\n", + " \n", + " scaler = StandardScaler()\n", + " y_train_gpr = scaler.fit_transform(np.array(Y_train).reshape(-1, 1))\n", + " y_test_gpr = scaler.transform(np.array(Y_test).reshape(-1, 1))\n", + " Y_true_gpr.append(y_test_gpr)\n", + " param_grid = {'kernel': [RBF(1e-6)], 'alpha': [1e-1]} \n", + " # [RBF(1e-6), RBF(1e-5), RBF(1e-4), RBF(1e-3), RBF(1e-2), RBF(1e-1), RBF(1), RBF(10), RBF(100), RBF(1000), RBF(10000)]\n", + " gpr_grid = GridSearchCV(GaussianProcessRegressor(random_state=1), param_grid=param_grid, cv=kf,\n", + " scoring='neg_mean_squared_error', verbose=0, n_jobs=1).fit(X_train, y_train_gpr) # internal set\n", + " Y_pred_GPR, Y_var = gpr_grid.best_estimator_.predict(X_test, return_std=True)\n", + " Y_pred_gpr.append(Y_pred_GPR)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How we can assess the reliability of the model's predictions?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "QSAR/QSPR models are not universal, and their predictive performance highly depends on similarity to training examples [[7](https://www.researchgate.net/publication/7498882_QSAR_applicabilty_domain_estimation_by_projection_of_the_training_set_descriptor_space_a_review), [8](https://www.researchgate.net/publication/7583123_Current_Status_of_Methods_for_Defining_the_Applicability_Domain_of_Quantitative_Structure-Activity_Relationships-The_Report_and_Recommendations_of_ECVAM_Workshop_52)]. Applicability Domain (AD) of a QSAR/QSPR model highlights a part of the chemical space containing those compounds for which the model is supposed to provide reliable predictions [[9](https://pubs.acs.org/doi/abs/10.1021/ci800151m), [10](https://pubs.acs.org/doi/10.1021/ci100253r)]. So, the problem of determining AD of a model is closely related to the problem of assessing the reliability of its predictions. According to the OECD (Organization of Economic Co-operation and Development) principles, QSAR/QSPR models should have “defined an applicability domain” [[11](https://www.oecd-ilibrary.org/environment/guidance-document-for-the-development-of-oecd-guidelines-for-testing-of-chemicals_9789264077928-en)]. \n", + "\n", + "Although numerous approaches are considered in the literature to assess the AD for the models predicting the properties of chemical compounds, ADs have almost never been applied to the models predicting characteristics of chemical reactions, and the problem of AD definition for chemical reactions has never been discussed in the literature. \n", + "\n", + "It is much more difficult to define AD for the models aimed at predicting different characteristics of chemical \n", + "reactions in comparison with standard QSAR/QSPR models dealing with the properties of chemical compounds because it is necessary to consider several additional factors (reaction representation, conditions, reaction type, atom-to-atom mapping, etc) that are specific for chemical reactions and should be taken into account. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Therefore, in this notebook, we will show the various AD definition methods extensively used in QSAR/QSPR studies, their modifications, as well as novel approaches designed by us for reactions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Approaches for defining applicability domain of QRPR model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "AD definition approaches are considered as binary classifiers returning True for X-inliers (within AD) and False \n", + "for X-outliers (outside AD). In this work, AD definition approaches are conditionally divided into two groups: \n", + "(1) universal and (2) ML-dependent. For **Universal AD definition approaches**, only the Random Forest Regression (RFR) was used for building QRPR models (see above, *4 cell*). For **ML-dependent AD definition approaches**, \n", + "both Random Forest Regression and Gaussian Process Regression machine learning methods were used for this \n", + "purpose (sea below, )." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Universal applicability domain definition approaches" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Universal AD definition methods can be used on top of QRPR models, which can be implemented by any suitable machine learning method. For example, above, for implementing QRPR models we used the Random Forest Regressor. A part of them (**Bounding Box**, **Fragment Control**, **Reaction Type Control**) gives an answer whether a test object is within AD or not. When using **Bounding Box** and **Fragment Control** needn't choose a value for any adjustable parameter. When using **Reaction Type Control**, it is usually necessary to choose a value of adjustable hyperparameter. Such methods correspond to the applicability aspect of AD definition according to Hanser et al. [12](https://www.tandfonline.com/doi/full/10.1080/1062936X.2016.1250229).\n", + "\n", + "For AD definition approaches which do not require hyperparameters selection, a regression QRPR model and an AD definition model were built on the external training set selected in each split of external cross-validation and both the property and applicability domain membership (within AD or out of AD) for the external test set were predicted. See below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Fragment Control (FC)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, if a Condensed Graph of Reaction representing a given reaction has fragments (subgraphs) missing in the training set, then it is considered to be an X-outlier (out of AD) whenever the corresponding QRPR model is applied. **Fragment Control** can formally be considered as a special case of **Bounding Box** for fragment descriptors. This method does not have adjustable parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "AD_FC = Pipeline([('cgr', CGR()), \n", + " ('frg', Fragmentor(version='2017', \n", + " max_length=4, \n", + " useformalcharge=True,\n", + " return_domain=True))]).fit(reactions_train).transform(reactions_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_FC.AD.values [:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Bounding box" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This approach defines AD as a D-dimensional hypercube with each edge spanning the range between the minimum and maximum values of the corresponding descriptor. If at least one descriptor for a given reaction is out of the range defined by the minimum and maximum values of the training set examples, the reaction is considered outside of the AD of the corresponding QRPR model. The method does not have adjustable hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from CIMtools.applicability_domain import Box\n", + "\n", + "AD_BB = Box().fit(X_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_BB[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To build the best AD definition models, it is necessary to optimize their thresholds and (hyper)parameters\n", + "by maximizing some AD performance metric. We will use the following metrics:\n", + "\n", + "(i) OIR (Out and in RMSE) - is the difference between RMSE of property prediction for reactions outside AD (denoted as RMSEOUT) and within AD (denoted as RMSEIN). The metric was first proposed by Sahigata et al [13](https://europepmc.org/article/pmc/pmc6268288). \n", + "\n", + "(ii) The Outliers Criterion metric shows how well AD definition detects Y-outliers. First, property prediction errors are estimated in cross-validation for all reactions in a dataset. The reactions for which the absolute prediction error is higher than 3×RMSE are identified as Y-outliers, while the rest are considered as Y-inliers. Y-Outliers (poorly predicted) that are predicted by AD definition as X-outliers (outside AD) are called true outliers (TO), while Y-inliers predicted by AD definition as X-inliers (within AD) are called true inliers (TI). False outliers (FO) are Y-inliers that are wrongly predicted by the AD definition as X-outliers, while false inliers (FI) are Y-outliers that are wrongly predicted by the AD definition as X-inliers. Then quality of outliers/inliers determination can be assessed using an analogue of the balanced accuracy and denoted as OD (Outliers Detection). \n", + "OD = (TO/(TO+FI)+TI/(TI+FO))/2.\n", + "\n", + "They are implemented in CIMtools.metrics.applicability_domain_metrics as well as slightly modified functions themselves are shown below. These functions will be needed later when working with 1-SVM" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import mean_squared_error, balanced_accuracy_score\n", + "\n", + "def balanced_accuracy_score_with_ad(Y_pred_test, AD_pred):\n", + " AD_true = abs(Y_pred_test[:, 0] - Y_pred_test[:, 1]) <= 3 * np.sqrt(mean_squared_error(Y_pred_test[:, 1],\n", + " Y_pred_test[:, 0]))\n", + " return balanced_accuracy_score(AD_true, AD_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def rmse_score_with_ad(Y_pred_test, AD_pred):\n", + " AD_out_n = ~AD_pred\n", + " s_n = sum(AD_pred)\n", + " s_out_n = sum(AD_out_n)\n", + " if s_n:\n", + " RMSE_AD = np.sqrt((sum(map(lambda x: (((x[0] - x[1]) ** 2) * x[2]), zip(Y_pred_test[:, 0], Y_pred_test[:, 1], AD_pred)))) / s_n)\n", + " else:\n", + " RMSE_AD = 0\n", + "\n", + " if s_out_n:\n", + " RMSE_AD_out_n = np.sqrt((sum(map(lambda x: (((x[0] - x[1]) ** 2) * x[2]), zip(Y_pred_test[:, 0], Y_pred_test[:, 1], AD_out_n)))) / s_out_n)\n", + " else:\n", + " RMSE_AD_out_n = 0 \n", + " return RMSE_AD_out_n - RMSE_AD" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### The procedure for selecting the hyperparameters of QRPR and the AD definition models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since hyperparameters of some AD definition methods need to be tuned, the above-mentioned nested cross-validation procedure was used, in which the inner 5-fold cross-validation loop was used for hyperparameter selection, whereas the outer 5-fold cross-validation loop was used for assessing predictive performance.\n", + "All hyperparameters of both QRPR and AD definition models were selected for each training/test split in the outer cross-validation loop by maximizing the OIR or OD metrics computed for the outer training set using the inner cross-validation loop. The selected hyperparameters were used to rebuild models on the outer training set, which were further used to predict reaction property and AD membership (within AD or out of AD) on the corresponding outer test set. After the completion of the outer loop, all values (predicted value and AD membership) predicted on individual outer test sets were merged, and the predictive performances of QRPR models with AD were assessed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Reaction Type Control" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the reaction centre of a chemical reaction is absent in the reactions in the training set, it is considered out of AD (X-outlier). Reaction centre is detected using reaction signatures [14](https://pubs.acs.org/doi/10.1021/acs.jcim.9b00102). Signature creation includes (1) representation of a chemical reaction as a **Condensed Graph of Reaction**, (2) highlight one or more reaction centers which are identified as a set of adjacent dynamic atoms and bonds on the CGR and (3) environment atoms of a certain radius R for each of the reaction centers, (4) introducing canonical numbering of atoms of the reaction center with the environment using an algorithm similar to the Morgan algorithm, (5) the signature is encoded by SMILES-like canonical string generated by CGRtools library [14](https://pubs.acs.org/doi/10.1021/acs.jcim.9b00102). For every atom hybridization and element label, as well as bond order is encoded in a signature. Since the method does not consider the whole structure, but only its reaction center with its closest environment, in order to be able to distinguish whether the aromatic cycle is part of the reaction center or its closest substituent, we introduced a separate type of hybridization for aromatic carbon atoms. We also used sp3, sp2, sp hybridization to describe the hybridization of not aromatic atoms. Hence, the signature includes information both on the reaction centre itself and its closest environment of radius R. The radius of environment included into the signature is a hyperparameter of the method. If the environment is set to 0, the reaction signature includes only atoms of the reaction centre." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + } + ], + "source": [ + "from CIMtools.model_selection import rtc_env_selection\n", + "from CIMtools.applicability_domain import ReactionTypeControl\n", + "\n", + "# First find the optimal radius. For this reason, we need to use rtc_env_selection function.\n", + "# We must pass variables such as X_train, Y_train , best_model - to build a model, reactions_train - list of reactions,\n", + "# envs - list of radius values and set the metric by which the hyperparameter will be optimized\n", + "\n", + "score = 'ba_ad' # or 'rmse_ad'\n", + "best_r = rtc_env_selection(X=X_train, y=Y_train, data=reactions_train, envs=[0, 1], \n", + " reg_model=est.best_estimator_, \n", + " score=score) # as you can see we use only the training set for optimize their thresholds and (hyper)parameters\n", + "# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n", + "AD_RTC_cv = ReactionTypeControl(env=best_r).fit(reactions_train).predict(reactions_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_r" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_RTC_cv[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another part of them Universal AD definition methods (**Leverage**, **Nearest Neighbors approach**, **One-Class SVM**, **Two-Class X-inlier/Y-outlier classifier**) return a continuous value indicating the reliability of prediction. When using these methods, it is usually necessary to choose a threshold for such a value, and for some of them, the values of adjustable hyperparameters. Such methods correspond to the reliability aspect of AD definition according to Hanser et al. [12](https://www.tandfonline.com/doi/full/10.1080/1062936X.2016.1250229)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Leverage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This method is based on the Mahalanobis distance to the centre of the training-set distribution. The leverage h of a chemical reaction is calculated based on the “hat” matrix as h=(xiT(XTX)-1xi), where X is the training-set descriptor matrix, and xi is the descriptor vector for the reaction i. The leverage threshold is usually defined as \n", + "h*=3*(M+1)/N, where M is the number of descriptors and N is the number of training examples. Chemical reactions with leverage values h > h* are considered to be chemically different from the training-set reactions, so they are marked as X-outliers. This approach is denoted hereafter as Leverage. " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from CIMtools.applicability_domain import Leverage\n", + "\n", + "AD_Leverage = Leverage(threshold='auto').fit(X_train, Y_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_Leverage[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([4.15652350e-02, 1.98145197e-02, 1.46696292e-02, 1.48282359e-02,\n", + " 1.52431107e-02, 1.23464406e-01, 5.20473763e-02, 5.74329729e-02,\n", + " 4.20180764e-02, 4.20079504e-02, 2.30468541e-02, 2.18049682e-02,\n", + " 2.86753369e-02, 2.44993317e-02, 5.09700774e-02, 9.61464732e-02,\n", + " 2.27451111e-01, 1.63327780e-02, 7.92303786e-02, 1.16420112e-01,\n", + " 4.58887340e-02, 3.59907720e-02, 4.07309659e-02, 4.43190571e-02,\n", + " 2.61784426e-01, 2.13167612e-01, 2.04666722e-01, 3.25919047e-02,\n", + " 1.16788774e-01, 1.17492578e-01, 1.16813131e-01, 8.22376478e-02,\n", + " 1.15253605e-02, 3.38165164e-02, 1.02515342e-01, 7.41550829e-02,\n", + " 1.53142791e-01, 5.01867128e-02, 5.39678368e-01, 7.58144403e-02,\n", + " 2.16744657e-01, 2.40962855e+09, 4.92783114e-02, 4.98929354e-02,\n", + " 2.36371597e-02, 2.37591924e-02, 3.35730464e-02, 3.39811253e-02,\n", + " 1.47671365e-02, 3.33357480e-01, 2.61120053e-02, 2.61477126e-02,\n", + " 8.88706593e-02, 1.14265996e-02, 1.21605609e-02, 1.21389393e-02,\n", + " 1.15915288e-02, 1.19598376e-02, 1.20026463e-02, 1.36505423e-01,\n", + " 2.69751990e-02, 2.67441108e-02, 4.01787224e-01, 6.05392499e-02,\n", + " 2.98150671e-01, 5.50863345e-02, 1.92325215e-02, 3.71916360e-02,\n", + " 5.17907355e-02, 9.46831603e-03, 9.35649488e-03, 8.81575832e-03,\n", + " 2.07010510e-02, 1.24832906e-02, 1.23683594e-02, 3.12714698e-01,\n", + " 2.02364661e-02, 2.09982894e-02, 1.29552232e-02, 4.47147420e-02,\n", + " 5.36589637e-02, 4.46614369e-02, 2.26420368e-02, 2.38017458e-02,\n", + " 1.00029647e+00, 8.58891479e-02, 8.55562342e-02, 2.48593586e-02,\n", + " 1.99603546e-01, 2.25847653e-01, 2.06017366e-01, 1.56636386e-02,\n", + " 1.64352269e-02, 1.51838058e-02, 1.57291746e-02, 7.36311722e-02,\n", + " 2.50090892e-01, 1.76952864e-02, 1.12791182e-01, 1.13174140e-01,\n", + " 2.64876028e-02, 2.63118474e-02, 1.56781743e-01, 1.56378634e-01,\n", + " 1.54148723e-01, 3.05965447e-02, 9.85347887e-02, 1.46744295e-02,\n", + " 2.25422337e-02, 5.98015728e-02, 5.72764548e-02, 9.87482419e-03,\n", + " 3.80516726e+09, 3.80516726e+09, 6.81003689e-02, 6.62494124e-02,\n", + " 6.69356783e-02, 6.65759318e-02, 6.63458497e-02, 5.89670168e-02,\n", + " 2.33056680e-02, 7.16663007e-02, 2.64929551e-02, 4.25681363e-02,\n", + " 4.65502427e-02, 4.60749081e-02, 5.91445601e-02, 4.33783600e-02,\n", + " 1.58466001e+12, 1.58466001e+12, 2.07518658e-02, 8.63992181e-02,\n", + " 9.73247717e-02, 4.91108463e+01, 4.91297813e+01, 5.49226814e-01,\n", + " 2.11352582e-01, 1.51779056e-01, 4.57561987e-02, 6.85371916e-01,\n", + " 4.39195281e-02, 4.38802296e-02, 2.83659757e-02, 3.28411166e-02,\n", + " 1.27008116e-02, 2.27827436e-02, 7.61694282e-02, 1.66944233e-02,\n", + " 3.19759575e-01, 5.28255296e-02, 5.00206889e-01, 5.00369804e-01,\n", + " 1.00028367e+00, 1.59333842e-02, 1.58810175e-02, 2.24369360e-02,\n", + " 3.80078868e-02, 1.00029556e+00, 3.36907688e-01, 5.61527634e-02,\n", + " 8.11041475e-02, 5.43044434e-02, 5.36672316e-02, 9.35005751e-03,\n", + " 4.24916651e-02, 1.93860578e-02, 1.81468668e-02, 8.73697006e-02,\n", + " 5.73359586e-02, 5.02582442e-01, 1.51135460e-01, 3.17599846e-02,\n", + " 1.42987716e-02, 1.47720036e-02, 3.50342252e-02, 4.06470595e-02,\n", + " 1.20164241e+00, 8.07555560e-02, 1.08935011e-01, 1.10034413e-01,\n", + " 1.74817706e-02, 1.53748285e-02, 1.51690941e-02, 1.51161860e-02,\n", + " 5.86435388e-02, 1.87851616e-02, 1.85452095e-02, 2.58049398e-02,\n", + " 1.79247319e-02, 2.59277054e-02, 1.99265471e-02, 1.95020847e-02,\n", + " 2.50171317e-02, 1.72421003e-02, 2.55881771e-02, 1.74858135e-02,\n", + " 1.73765977e-02, 1.69532134e-02, 1.74963935e-02, 1.73497465e-02,\n", + " 1.80088054e-02, 2.09132005e-01, 1.55107672e-01, 4.67942935e-02,\n", + " 3.26555209e-02, 3.09920970e-02, 3.08620323e-02, 3.28037263e-02,\n", + " 8.70165942e-02, 8.70669236e-02, 2.12521409e-02, 2.20345605e-02,\n", + " 3.35766991e-02, 7.68233122e-02, 9.63174070e-02, 9.60952995e-02,\n", + " 1.93564691e-01, 1.94040174e-01, 1.17335099e-02, 2.88476035e-01,\n", + " 4.02119220e-02, 5.57151544e-02, 3.69187688e-02, 2.61910942e-02,\n", + " 1.00011559e+00, 1.13531911e-01, 1.41493561e-02, 3.36710915e-02,\n", + " 1.49552506e-01, 2.15278981e-02, 2.13992555e-02, 2.24245442e-02,\n", + " 2.13636870e-02, 1.31836050e-02, 9.33979898e-03, 1.14031315e-02,\n", + " 9.10079849e-02, 8.80691223e-02, 9.11401470e-02, 1.44207762e-02,\n", + " 1.47463684e-02, 3.62182691e-01, 1.14964822e-01, 3.30217738e-01,\n", + " 2.30379225e-02, 2.52030326e-02, 2.68156616e-02, 1.63886553e-01,\n", + " 5.58211843e-02, 5.69266462e-02, 2.88089469e+09, 1.58768984e-02,\n", + " 1.21321220e-02, 1.93127585e-02, 2.39899301e-02, 1.68041431e-02,\n", + " 3.34575810e-02, 1.23624711e-02, 2.14383929e-02, 1.20639802e-02,\n", + " 1.45835329e-02, 1.09110441e-02, 7.23988617e-02, 7.03286416e-02,\n", + " 2.27827436e-02, 3.35087928e-01, 5.55785279e-02, 6.98660565e-02,\n", + " 4.71816129e-02, 1.35553101e-01, 1.35566653e-01, 2.75481591e-02,\n", + " 9.15904955e-01, 4.21831181e-02, 1.58631400e-01, 1.04026783e-02,\n", + " 1.05293393e-02, 5.03871268e-02, 5.03093439e-02, 5.03340710e-02,\n", + " 8.22002764e-01, 4.03572578e-01, 3.64070170e-02, 4.45824360e-02,\n", + " 1.00082011e+00, 3.40340877e-01, 6.66987086e-01, 2.43594246e-02,\n", + " 2.48211047e-02, 3.21304668e-01, 1.19373558e-01, 4.93225718e-02,\n", + " 4.93024340e-02, 1.86013533e-01, 1.85268752e-01, 1.04705621e-01,\n", + " 8.62949808e-02, 2.66233526e-02, 2.58031318e-02, 7.20481728e-02,\n", + " 1.05728508e-01, 2.14072129e-02, 1.94336408e-01, 5.02473598e-02,\n", + " 5.00910886e-02, 4.05869890e-02, 8.62927381e-02, 2.64178586e-02,\n", + " 2.51006393e-02, 1.89782733e-02, 1.81515892e-02, 1.24711750e-02,\n", + " 1.26857378e-02, 1.85938446e-02, 1.27324598e-02, 1.31258465e-02,\n", + " 1.28602968e-02, 1.58111202e-02, 1.84543781e-02, 1.29729970e-02,\n", + " 1.85993748e-02, 1.26976630e-02, 1.24961660e-02, 4.65574814e-02,\n", + " 4.63648688e-02, 5.77883914e-02, 1.98138207e+00, 4.25016261e-02,\n", + " 1.35615791e-02, 1.61251907e-02, 2.67895076e-02, 1.74786719e-01,\n", + " 6.19829294e-02, 6.35934629e-02, 5.46047481e-02, 5.04383171e-02,\n", + " 5.06159150e-02, 5.03113412e-02, 1.89597140e-02, 2.12535376e-02,\n", + " 2.09086421e-02, 3.36907688e-01, 2.00250403e-02, 2.29128245e-02,\n", + " 1.08382574e-02, 1.07629715e-02, 1.78677478e-02, 1.82728296e-02,\n", + " 1.84128195e-02, 1.92946719e-02, 1.93213877e-02, 1.86915077e-02,\n", + " 1.81187582e-02, 1.81901659e-02, 2.83707399e-02, 6.64094104e-01,\n", + " 1.61820688e-02, 1.60029892e-02, 7.27600027e-02, 8.45070530e-02,\n", + " 8.48480366e-02, 5.22811317e-02, 5.18946873e-02, 2.17860938e-02,\n", + " 2.19619236e-02, 2.17779194e-02, 3.18473671e-02, 2.30094317e-02,\n", + " 2.19700466e-02, 2.83885723e-02, 1.42401844e-02, 1.69882325e-02,\n", + " 1.70562294e-02, 1.32078625e-01, 3.29184565e-02, 3.18986905e-02,\n", + " 5.01679925e-01, 6.75169747e-02, 2.87241687e-02, 2.97247853e-02,\n", + " 2.57761221e-02, 9.99800112e-01, 8.08397642e-02, 1.54228605e-02,\n", + " 8.35741598e-02, 2.58621125e-02, 4.96668499e-02, 1.63759745e-02,\n", + " 1.19851730e+00, 1.17385310e-02, 1.59639967e-02, 1.54028117e-02,\n", + " 2.13151185e-02, 1.17269252e-02, 1.31383632e-02, 1.02905527e-02,\n", + " 1.37725632e-02, 1.74641521e-02, 1.08461446e-02, 1.59134369e-02,\n", + " 3.90915590e-01, 5.09763421e-02, 4.13618774e-02, 2.22594098e-02,\n", + " 1.37048046e-01, 6.96404866e-01, 1.45687658e-02, 2.49165457e-02,\n", + " 1.62789297e-02, 1.55489722e-02, 2.60412179e-02, 2.64130321e-02,\n", + " 1.93674015e-02, 2.64721726e-02, 1.09991666e-02, 2.51506921e-02,\n", + " 1.60817702e-02, 2.62280030e-02, 1.27836367e-01, 1.31027670e-01,\n", + " 1.26256023e-01, 3.36796731e-02, 1.42273577e-02, 1.40748225e-02,\n", + " 1.42987716e-02, 1.49757399e-02, 1.73689367e-02, 2.60589963e-02,\n", + " 1.51274456e-02, 1.79004698e-02, 1.35402679e-02, 1.34889476e-02,\n", + " 6.38096869e-02, 6.14749581e-02, 8.27647549e-02, 8.26580870e-02,\n", + " 1.89875875e-02, 2.46796115e-02, 2.50684258e-02, 1.73493336e-02,\n", + " 1.59005293e-02, 3.63611475e-01, 2.15451030e-02, 3.99113510e-02,\n", + " 1.34569189e-01, 2.90244856e-01, 2.26172638e-02, 4.25776948e-02,\n", + " 1.81336506e-01, 1.35840040e-02, 1.39306803e-02, 1.60301306e-02,\n", + " 1.56552865e-02, 1.34767501e-02, 1.15754739e-01, 5.25401874e-02,\n", + " 1.48205104e-02, 1.06473059e-02, 1.69517782e-02, 1.73843039e-02,\n", + " 2.50048817e-02, 1.24880693e-02, 2.60441613e-02, 1.59463171e-02,\n", + " 1.27008116e-02, 1.55250695e-02, 1.47959803e-02, 1.61639734e-02,\n", + " 1.64079152e-02, 5.80372976e-02, 4.74944128e-02, 4.75250292e-02,\n", + " 2.32183159e+00, 2.80865300e-02, 2.78076443e-02, 8.02219596e-02,\n", + " 9.02534287e-03, 5.92543719e-02, 5.17148021e-02, 3.96845366e-02,\n", + " 3.98006220e-02, 1.61583123e-01, 2.70137036e-02, 1.19062816e-02,\n", + " 1.18423618e-02, 4.60304774e-01, 3.02824616e-01, 2.84148869e-02,\n", + " 2.86301196e-02, 5.05538577e-01, 5.05019487e-01, 7.70048544e-02,\n", + " 2.84816302e-02, 1.43355263e-02, 7.57698631e-02, 7.58718365e-02,\n", + " 1.55583954e-01, 1.55754538e-01, 1.01054833e-01, 3.67160794e-02,\n", + " 2.70719193e-02, 1.27703962e-02, 3.97016600e-01, 1.48770099e-02,\n", + " 1.47554642e-02, 6.68787762e-02, 2.08138619e-02, 2.00875650e-02,\n", + " 2.01317701e-02, 7.90052723e-02, 1.96882220e-02, 2.99542467e-02,\n", + " 1.93528875e-02, 3.32003635e-02, 2.25904294e-01, 2.25783485e-01,\n", + " 1.39998624e-02, 1.34859595e-02, 2.63065048e-02, 1.43845795e-02,\n", + " 1.00032457e+00, 1.42964203e-02, 1.32487160e-02, 1.55253780e-02,\n", + " 2.10883637e-02, 1.42120842e-02, 1.51838236e-02, 1.42380259e-02,\n", + " 2.14302554e-02, 1.43170351e-02, 2.14342726e-02, 2.10312560e-02,\n", + " 1.41719928e-02, 1.41149727e-02, 1.81794134e-02, 9.46623047e-02,\n", + " 9.32812602e-02, 9.34229735e-02, 3.48520133e-02, 1.04476343e-02,\n", + " 5.27484755e-02, 1.27445779e-01, 1.27909945e-01, 1.27712191e-01,\n", + " 1.16982154e-01, 1.16720238e-01, 1.72196513e-02, 1.72594036e-02,\n", + " 5.49238320e-01, 8.27994455e-01, 3.63711021e-02, 3.80612849e-02,\n", + " 3.64567592e-02, 3.67093441e-02, 3.94651124e-02, 6.66919127e-01,\n", + " 6.04317578e+10, 2.60910986e-02, 3.36998252e-02, 1.93161386e-01,\n", + " 2.76921902e-02, 1.18755591e-01, 9.47532943e-02, 2.37114110e-02,\n", + " 1.97594094e-02, 1.90097520e-02, 1.87143349e-01, 1.69735080e-02,\n", + " 1.69543121e-02, 1.69883568e-02, 6.66408823e-02, 6.67836631e-02,\n", + " 6.74965121e-02, 5.78424847e-02, 5.78960591e-02, 5.88917791e-02,\n", + " 2.53043427e-02, 2.59867873e-02, 1.39816162e-02, 8.74358916e-02,\n", + " 5.96052950e-02, 3.63008803e-02, 2.48916758e-02, 1.12562710e-02,\n", + " 3.56003229e-02, 3.53116238e-02, 3.53211599e-02, 3.76235543e-02,\n", + " 5.99390668e-02, 3.49485769e-01, 9.84935079e-02, 9.60091766e-02,\n", + " 2.50161908e-01, 4.25016261e-02, 1.42500750e-02, 1.35140680e-02,\n", + " 1.27795523e-02, 1.37991845e-02, 1.46464880e-02, 1.17261029e-02,\n", + " 2.27711843e-02, 1.28582350e-02, 1.90349696e-02, 1.23501704e-02,\n", + " 1.18110330e-02, 1.40024463e-02, 1.94364608e-02, 2.07699815e-02,\n", + " 1.19190001e-02, 2.00312749e-02, 1.52443207e-02, 1.89072137e-02,\n", + " 1.16709907e-02, 1.27503908e-02, 1.19673922e-02, 1.27984544e-02,\n", + " 2.98808912e-02, 3.94043547e-02, 3.78844862e-02, 5.47246921e-02,\n", + " 2.40745014e-02, 2.58021898e-02, 8.50095915e-02, 2.42755252e-02,\n", + " 2.59785233e-02, 9.17099452e-03, 4.50628923e-02, 3.52214163e-02,\n", + " 1.00002431e+00, 4.03497444e-02, 3.00818352e-01, 1.64343774e-01,\n", + " 8.07069957e-02, 5.57020739e-02, 1.78935361e-02, 1.58185884e-02,\n", + " 2.29241215e-02, 2.94194212e-02, 3.33321799e-01, 1.71762602e-01,\n", + " 1.71548053e-01, 9.05164982e-03, 3.29983604e-02, 3.43974640e-02,\n", + " 3.49767120e-02, 3.54707258e-02, 3.62642267e-02, 3.64402675e-02,\n", + " 3.90867928e-02, 3.70576141e-02, 3.00248366e-02, 3.29686045e-02,\n", + " 2.18289379e-02, 2.24040078e-02, 2.19743722e-02, 2.21750642e-02,\n", + " 2.84545589e-02, 3.58205899e-02, 1.45997695e-02, 7.37984040e-01,\n", + " 7.38047986e-01, 3.36131565e-01, 2.84432916e-01, 2.84860350e-01,\n", + " 2.40262548e-02, 1.66128273e-02, 2.96266508e-02, 3.01873455e-02,\n", + " 1.21287122e-02, 1.65723621e-02, 1.42053291e-02, 1.13274821e-02,\n", + " 2.24744723e-02, 1.24989899e-01, 1.60165581e-01, 1.74294440e-02,\n", + " 3.36992295e-02, 4.30040488e-02, 4.50828365e-02, 8.17672446e-02,\n", + " 5.08194873e-02, 4.63590255e-02, 4.62908706e-02, 2.70423950e-02,\n", + " 2.17150724e-02, 2.08292534e-02, 1.53781184e-02, 2.07518658e-02,\n", + " 2.63991981e-02, 8.38570989e-02, 1.51219206e-02, 1.45052778e-02,\n", + " 1.24497195e-02, 1.37731019e-02, 9.60952995e-02, 1.19065356e-02,\n", + " 1.25283866e-02, 2.48209760e-02, 2.30612214e-02, 2.84032218e-02,\n", + " 2.85880449e-02, 2.82693108e-02, 4.07696964e-02, 4.07277066e-02,\n", + " 6.05188448e-02, 2.54689006e-02, 1.22174870e-01, 1.13723546e-01,\n", + " 5.39282829e-02, 2.32629373e-01, 1.51055872e-02, 1.10591105e-01,\n", + " 1.24069392e-02, 1.68292768e-02, 2.43317303e-01, 4.51452782e-02,\n", + " 4.57182373e-02, 2.69476467e-02, 2.51840476e-01, 2.85332085e-02,\n", + " 6.95222413e-02, 1.14155929e-01, 1.14499871e-01, 2.52412003e-02,\n", + " 2.75499046e-02, 2.92961890e-01, 2.34126467e-02, 9.35536866e-02,\n", + " 9.59660895e-02, 1.66742081e-01, 1.01422936e+00, 8.57167170e-02,\n", + " 4.57822158e-02, 2.45945672e-02, 9.02211903e-02, 2.14863052e-02,\n", + " 2.96223255e-02, 1.20345533e-01, 3.29629845e-02, 3.30504936e-02,\n", + " 1.82433972e-02, 1.97688554e-02, 1.85047618e-02, 2.09648689e-02,\n", + " 2.06765151e-02, 1.94285685e-02, 1.90760036e-02, 6.46498356e-02,\n", + " 2.59880087e+10, 2.59880087e+10, 1.15635257e-01, 1.15767788e-01,\n", + " 7.98848958e-02, 2.85556640e-02, 2.30800212e-02, 2.27064430e-02,\n", + " 2.26412933e-02, 1.26836637e-01, 1.31273973e-02, 7.92608619e-03,\n", + " 9.19101695e-03, 7.91554038e-03, 1.97895086e-02, 9.21010492e-03,\n", + " 6.24728953e+10, 6.73855739e-02, 2.44427094e-01, 2.93901310e-02,\n", + " 2.90899167e-02, 1.12201458e-01, 1.29599298e-01, 2.02777237e-02,\n", + " 2.08723387e-02, 2.38898461e-02, 1.04076508e-01, 4.20637860e-02,\n", + " 1.00036488e+00, 4.89022784e+11, 5.61071703e-02, 8.02830980e-02,\n", + " 2.29937177e-01, 1.80481028e-02, 1.81468668e-02, 6.35480446e-02,\n", + " 8.32049987e-02, 8.27792394e-02, 8.31563291e-02, 2.74379860e-02,\n", + " 2.85556289e-02, 5.01063808e-01, 2.08806041e-02, 2.52383980e-01,\n", + " 3.65634530e-02, 2.61607822e-01, 5.00271736e-01, 2.11269194e-02,\n", + " 2.05150840e-02, 2.06351597e-02, 7.83238659e-02, 6.40729205e-02,\n", + " 2.49572736e-02, 2.63598216e-01, 3.18388346e-02, 1.29182738e+00,\n", + " 5.06268811e-01, 5.00367285e-01, 1.55854407e-01, 8.67539036e-02,\n", + " 5.00039243e-01, 2.83902978e-01, 5.00003040e-01, 2.03285051e-02,\n", + " 1.98327684e-02, 2.13301255e-02, 1.97806848e-02, 2.02845566e-02,\n", + " 2.15869019e-02, 2.90721115e-02, 2.17215873e-02, 2.25779515e-02,\n", + " 2.28859461e-02, 2.04718288e-02, 2.06740274e-02, 2.13549206e-02,\n", + " 1.14853010e-01, 5.83450966e-02, 5.86860578e-02, 2.08138619e-02,\n", + " 2.95612257e-02, 2.96176510e-02, 7.11378162e-02, 1.62304796e-02,\n", + " 2.47530705e-02, 2.06151223e-02, 1.86643789e-02, 4.76258099e-02,\n", + " 4.79931628e-02, 1.03239191e-02, 1.75756392e-02, 1.89420862e-02,\n", + " 1.55713371e-02, 1.62934809e-02, 6.14037753e-02, 5.28503292e-02,\n", + " 1.42600459e+10, 3.34348638e-01, 4.60682063e-02, 4.45236771e-02,\n", + " 1.40583764e-02, 1.32133961e-02, 1.38587033e-02, 1.31084426e-02,\n", + " 1.01391732e-02, 1.66744397e-01, 2.26406837e-01, 1.19846206e-02,\n", + " 1.19477978e-02, 1.21069635e-02, 2.33854112e-02, 1.61033319e-02,\n", + " 3.47933067e-02, 3.24004767e-02, 3.36049838e-02, 4.83319918e-02,\n", + " 4.88051160e-02, 1.43261221e-01, 2.99275604e-02, 2.06911474e-01,\n", + " 2.48593586e-02, 5.34807881e-02, 2.12765592e-01, 2.65568395e-02,\n", + " 2.39112234e-02, 1.49668641e-02, 4.33836412e-02, 4.34213464e-02,\n", + " 3.84163250e-02, 1.04837347e-01, 1.36239928e-02, 4.46474516e-02,\n", + " 6.51191486e-01, 5.54606867e-02, 1.01293252e-02, 9.34600787e-03,\n", + " 1.25106162e-02, 4.68413773e-02, 2.19690713e-02, 2.39366005e-02,\n", + " 9.53095651e-03, 1.25628714e-02, 1.28794308e-02, 4.73486619e-02,\n", + " 4.79755498e-02, 2.68630761e-02, 2.91536330e-02, 2.17860938e-02,\n", + " 2.19619236e-02, 2.49145578e-02, 1.42600459e+10, 3.33645546e-01,\n", + " 3.33856105e-01, 2.40499473e-02, 2.36466120e-02, 1.59859587e-01,\n", + " 2.12683351e-01, 2.51584307e-02, 5.00077114e-01, 2.39671995e-02,\n", + " 3.30144211e-02, 8.79741031e-02, 1.57212102e-02, 1.02412363e-02,\n", + " 1.04667710e-01, 1.18027946e-01, 1.19088964e-01, 1.19326022e-01,\n", + " 1.35277998e-01, 1.35571455e-01, 1.35288891e-01, 3.81267066e-02,\n", + " 4.54171998e-02, 4.48601208e-02, 2.10525119e-02, 2.03944391e-02,\n", + " 2.00469867e-02, 5.14077545e-02, 1.00077145e+00, 2.52036651e-01,\n", + " 9.91375194e-02, 9.90044945e-02, 5.90921601e-02, 1.19443518e-02,\n", + " 1.21193013e-02, 1.68592810e-02, 1.29446375e-02, 4.77459720e-02,\n", + " 5.68468919e-02, 1.56066076e-02, 3.66155243e-02, 2.78871915e-01,\n", + " 3.44565117e-02, 3.21766780e-02, 1.47881300e-02, 4.91339144e-02,\n", + " 8.96558607e-02, 9.22060109e-02, 1.74562963e-02, 8.50767733e-02,\n", + " 3.45236903e-02, 1.01412338e+00, 1.28047802e-02, 1.26931429e-02,\n", + " 6.27932605e-02, 1.03216617e-02, 5.35191668e-02, 1.15205397e-01,\n", + " 1.16616817e-01, 1.17349910e-01, 5.90405344e-02, 5.89076863e-02,\n", + " 1.68795124e-02, 2.87137832e-02, 1.69519789e-02, 1.21922438e-02,\n", + " 1.28366371e-02, 1.39163533e-02])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if you want to predict the distances for X to center of the training set, you should write \n", + "Leverage(threshold='auto').fit(X_train, Y_train).predict_proba(X_test) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Lev_cv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The drawback of it is the absence of strict rules for choosing the threshold h*. As an alternative, an optimal threshold value h* can be found using an internal cross-validation procedure by maximizing some AD performance metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + } + ], + "source": [ + "AD_Lev_cv = Leverage(threshold='cv', score=score, \n", + " reg_model=est.best_estimator_).fit(X_train, Y_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, False, True, True, True,\n", + " True])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_Lev_cv[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Z-1NN" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This AD definition is based on the distance(s) between a current reaction and the closest training-set reaction(s). Usually, one nearest neighbour is considered (k=1). If the distance is not within the user-defined threshold, then the prediction is considered unreliable and the reaction is considered as X-outlier. The threshold value is commonly taken as Dc=Zσ+, where is the average and σ is the standard deviation of the Euclidean distances between nearest neighbours in the training set, Z is an empirical parameter to control the significance level, with the recommended value of 0.5." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from CIMtools.applicability_domain import SimilarityDistance\n", + "\n", + "AD_Z1NN = SimilarityDistance(threshold='auto').fit(X_train, Y_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_Z1NN[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1.02001345e-04, 5.53194200e-05, 2.90006968e-01, 1.60658619e-04,\n", + " 3.06629569e-04, 2.35221173e-01, 3.33928136e-01, 5.89367610e-05,\n", + " 1.60658619e-04, 1.45970950e-04, 4.15819323e-04, 1.68378995e-04,\n", + " 9.02001109e-01, 1.10650395e+00, 1.35528595e-01, 0.00000000e+00,\n", + " 9.68561444e+00, 0.00000000e+00, 8.08610280e-05, 6.21910250e+00,\n", + " 1.07160581e+00, 2.49572261e+00, 3.96237352e-01, 0.00000000e+00,\n", + " 0.00000000e+00, 1.92111473e+00, 1.92111473e+00, 4.86333198e-05,\n", + " 1.12532006e-06, 9.50442516e-05, 4.37599359e-05, 9.28872322e-05,\n", + " 0.00000000e+00, 4.56781536e-01, 3.31064411e+00, 1.05339199e-04,\n", + " 9.62943201e-08, 0.00000000e+00, 2.88661719e+00, 0.00000000e+00,\n", + " 0.00000000e+00, 5.75867802e+00, 9.88196756e-05, 8.25195377e-05,\n", + " 4.71512744e-05, 5.01863557e-05, 9.28872322e-05, 3.38049547e-06,\n", + " 5.91917911e-05, 3.14835139e-05, 3.36203000e-05, 3.27651480e-05,\n", + " 1.12525909e-04, 9.88196756e-05, 9.88196756e-05, 2.04158874e-04,\n", + " 1.05339199e-04, 5.53194200e-05, 1.69992972e-04, 9.80445600e-05,\n", + " 5.35242092e-05, 5.18149897e-05, 5.83207881e+00, 1.45970950e-04,\n", + " 3.06975694e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", + " 1.68378995e-04, 3.20000000e-01, 0.00000000e+00, 0.00000000e+00,\n", + " 0.00000000e+00, 5.53194200e-05, 5.18149897e-05, 1.50634665e-04,\n", + " 4.59396954e-05, 1.10650396e+00, 0.00000000e+00, 5.53194200e-05,\n", + " 3.33928141e-01, 5.18149897e-05, 1.00000337e-01, 2.59478294e-04,\n", + " 1.50634665e-04, 9.57845942e-05, 1.02001345e-04, 5.01863557e-05,\n", + " 1.33944615e+00, 4.47367865e+00, 1.45970950e-04, 5.53194200e-05,\n", + " 1.16398280e-04, 1.16398280e-04, 9.88196756e-05, 5.90000131e-01,\n", + " 4.30394558e-05, 9.88196756e-05, 5.48110127e-05, 1.16693963e-04,\n", + " 9.28872322e-05, 1.05339199e-04, 1.12525909e-04, 1.05339199e-04,\n", + " 1.45970950e-04, 3.52909863e-01, 3.06975694e-04, 2.43080018e-04,\n", + " 5.61529622e-06, 1.02001345e-04, 5.89367610e-05, 5.18149897e-05,\n", + " 1.22572964e+01, 1.22572964e+01, 0.00000000e+00, 0.00000000e+00,\n", + " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.24676057e-04,\n", + " 0.00000000e+00, 5.72636348e+00, 1.91706908e-04, 5.91917911e-05,\n", + " 1.53617827e-04, 2.82912976e-04, 2.59478294e-04, 2.59478294e-04,\n", + " 1.40644166e+02, 1.40644166e+02, 1.05339199e-04, 1.64517062e-04,\n", + " 1.45970950e-04, 8.87992618e+01, 8.88037627e+01, 1.87865639e-04,\n", + " 4.30759794e-01, 1.60658619e-04, 2.26950772e-05, 1.32358163e+01,\n", + " 1.60658619e-04, 1.45970950e-04, 1.20474057e-04, 1.05339199e-04,\n", + " 9.28872322e-05, 1.92111473e+00, 5.18149897e-05, 1.29543338e-04,\n", + " 2.59478294e-04, 4.86587185e-05, 8.86433167e-05, 1.31907539e-04,\n", + " 1.50634665e-04, 5.35242092e-05, 1.05339199e-04, 1.05339199e-04,\n", + " 1.88418175e+00, 1.55525555e-04, 7.21357198e+00, 9.19812655e+00,\n", + " 3.52909863e-01, 1.29295148e-04, 1.20474057e-04, 0.00000000e+00,\n", + " 0.00000000e+00, 3.45506029e-05, 1.68378995e-04, 8.33690283e-05,\n", + " 5.89367610e-05, 6.25854963e-01, 1.50113135e-04, 8.58014819e-05,\n", + " 5.84100459e-05, 7.79702148e-05, 2.17865108e-04, 1.49594450e+00,\n", + " 4.66222519e+00, 1.64517062e-04, 1.94088701e-04, 1.21361156e-04,\n", + " 5.08280495e-01, 3.54341854e-05, 3.73989278e-05, 3.95317316e-05,\n", + " 6.63854480e-05, 1.25609549e-04, 1.37270655e-04, 1.25609549e-04,\n", + " 1.37270655e-04, 1.15373812e-04, 1.15373812e-04, 2.40983362e-04,\n", + " 9.28872322e-05, 1.15373812e-04, 1.25609549e-04, 1.15373812e-04,\n", + " 1.25609549e-04, 3.84431530e-05, 2.00003165e-02, 2.00000000e-02,\n", + " 2.00013572e-02, 1.50634665e-04, 3.06975694e-04, 6.31686249e+00,\n", + " 6.18784642e-01, 3.36203000e-05, 3.54341854e-05, 9.88196756e-05,\n", + " 6.18784664e-01, 6.18784642e-01, 3.95317316e-05, 5.35242092e-05,\n", + " 5.35242092e-05, 8.74734346e-05, 1.02001345e-04, 0.00000000e+00,\n", + " 1.54915725e+01, 1.54915725e+01, 5.08280495e-01, 5.90566073e+00,\n", + " 1.16921042e+00, 2.26336453e-04, 1.03600362e-04, 5.72064887e-05,\n", + " 1.08843629e-04, 4.30394558e-05, 5.72064887e-05, 6.23759450e+00,\n", + " 1.50113135e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", + " 0.00000000e+00, 5.35242092e-05, 1.60658619e-04, 3.88945605e-04,\n", + " 9.28872322e-05, 1.05339199e-04, 9.28872322e-05, 5.18149897e-05,\n", + " 5.53194200e-05, 2.01377179e+01, 0.00000000e+00, 1.02001345e-04,\n", + " 9.88196756e-05, 1.60658619e-04, 8.01601680e-07, 1.45970950e-04,\n", + " 3.63966786e-05, 4.95027713e+00, 6.48745998e+00, 5.91917911e-05,\n", + " 1.91706908e-04, 2.59478294e-04, 1.29325761e-04, 5.00006733e-02,\n", + " 2.61306334e-01, 5.35242092e-05, 2.34290029e-04, 9.28872322e-05,\n", + " 8.74734346e-05, 4.86333198e-05, 6.18784642e-01, 6.18784642e-01,\n", + " 1.92111473e+00, 3.96237315e-01, 4.27144004e-01, 8.74734346e-05,\n", + " 1.02001345e-04, 5.53194200e-05, 9.88196756e-05, 0.00000000e+00,\n", + " 2.30041106e+01, 1.52839000e-04, 1.60658619e-04, 5.35242092e-05,\n", + " 1.05339199e-04, 9.83276612e-05, 3.19422132e-05, 6.47073612e-05,\n", + " 2.00144918e+01, 4.36080561e+00, 0.00000000e+00, 0.00000000e+00,\n", + " 2.59478294e-04, 1.07672929e+01, 1.50634665e-04, 1.05339199e-04,\n", + " 2.04158874e-04, 5.28177173e-01, 1.02001345e-04, 0.00000000e+00,\n", + " 0.00000000e+00, 1.66093699e+01, 1.66093699e+01, 0.00000000e+00,\n", + " 0.00000000e+00, 1.69992972e-04, 1.60658619e-04, 8.25195377e-05,\n", + " 9.40030706e+00, 1.45970950e-04, 1.10327338e+01, 5.53194200e-05,\n", + " 9.88196756e-05, 1.12525909e-04, 0.00000000e+00, 2.17865108e-04,\n", + " 1.45970950e-04, 9.88196756e-05, 1.05339199e-04, 2.00000000e-02,\n", + " 1.60658619e-04, 9.88196756e-05, 5.91917911e-05, 2.59478294e-04,\n", + " 1.16398280e-04, 1.16398280e-04, 5.53194200e-05, 1.16398280e-04,\n", + " 5.53194200e-05, 5.91917911e-05, 5.53194200e-05, 1.91706908e-04,\n", + " 9.28872322e-05, 1.11059083e+00, 9.41875410e+00, 0.00000000e+00,\n", + " 1.60658619e-04, 3.88945605e-04, 1.02001345e-04, 1.92111473e+00,\n", + " 1.00966693e-04, 3.06975694e-04, 6.18784642e-01, 4.57359578e-05,\n", + " 1.80187346e-07, 4.86333198e-05, 2.52189229e+00, 2.71118143e+00,\n", + " 2.71118142e+00, 7.21357198e+00, 3.84593092e-05, 1.12525909e-04,\n", + " 0.00000000e+00, 0.00000000e+00, 1.60658619e-04, 1.20000000e-01,\n", + " 1.20000320e-01, 5.18149897e-05, 5.53194200e-05, 1.20000000e-01,\n", + " 1.20000108e-01, 2.77056899e-04, 1.23756930e+00, 9.49969519e+00,\n", + " 3.06629569e-04, 1.45970950e-04, 1.13933237e+01, 8.74734346e-05,\n", + " 8.25195377e-05, 9.28939755e-07, 0.00000000e+00, 5.35242092e-05,\n", + " 5.01863557e-05, 0.00000000e+00, 6.03336673e+00, 4.46751675e-01,\n", + " 1.99085409e-01, 1.92437702e-04, 6.12822658e-05, 5.01863557e-05,\n", + " 1.10000000e-01, 8.97545793e-06, 1.16398280e-04, 1.08843629e-04,\n", + " 1.60658619e-04, 1.11059083e+00, 3.75876574e-04, 1.00000049e-01,\n", + " 1.05339199e-04, 0.00000000e+00, 4.86333198e-05, 2.57171149e-01,\n", + " 1.62292945e-01, 4.27144004e-01, 5.80719381e-01, 6.18784642e-01,\n", + " 4.66222519e+00, 3.54517428e-04, 2.59478294e-04, 8.70158952e-05,\n", + " 9.88196756e-05, 1.02001345e-04, 1.02001345e-04, 1.11437409e-04,\n", + " 1.02001345e-04, 9.84076141e-05, 2.16845029e-04, 7.30148774e-05,\n", + " 7.25704872e+00, 3.65308092e-01, 1.46807704e-04, 5.53194200e-05,\n", + " 3.75847847e-04, 4.85384448e+00, 1.92437702e-04, 1.91706908e-04,\n", + " 1.08843629e-04, 1.08843629e-04, 1.68290196e-04, 8.02055781e-07,\n", + " 1.10000000e-01, 0.00000000e+00, 5.18149897e-05, 1.98389293e+00,\n", + " 1.92437702e-04, 0.00000000e+00, 5.45916805e-04, 2.51777700e-01,\n", + " 1.00396783e-04, 5.53674092e-01, 0.00000000e+00, 0.00000000e+00,\n", + " 5.84100459e-05, 7.07125188e-05, 9.88196756e-05, 1.91706908e-04,\n", + " 1.60658619e-04, 9.28872322e-05, 3.95317316e-05, 3.73989278e-05,\n", + " 1.21431983e-04, 1.21431983e-04, 1.80360667e-04, 9.28872322e-05,\n", + " 0.00000000e+00, 9.00000000e-02, 1.50000000e-01, 0.00000000e+00,\n", + " 6.12822658e-05, 2.15583923e+01, 1.45970950e-04, 0.00000000e+00,\n", + " 3.11498974e-05, 1.04384434e+00, 1.09968949e-04, 3.45093731e-05,\n", + " 1.56051558e+01, 5.72064887e-05, 5.72064887e-05, 1.60658619e-04,\n", + " 9.28872322e-05, 9.88196756e-05, 1.04130714e+00, 2.47440328e-04,\n", + " 0.00000000e+00, 1.02007916e-06, 2.63681806e-04, 1.73196658e-04,\n", + " 9.28872322e-05, 5.53194200e-05, 8.74734346e-05, 1.59638540e-04,\n", + " 9.28872322e-05, 9.28872322e-05, 2.04081024e-06, 1.80360667e-04,\n", + " 1.29543338e-04, 9.30840769e-05, 9.28872322e-05, 9.88196756e-05,\n", + " 7.83224936e+00, 8.00000000e-02, 4.57359578e-05, 0.00000000e+00,\n", + " 5.53194200e-05, 1.45970950e-04, 1.10187924e+01, 3.27651480e-05,\n", + " 6.47073612e-05, 8.08454715e-01, 1.91706908e-04, 5.18149897e-05,\n", + " 5.53194200e-05, 5.80099096e+00, 2.02711005e+01, 4.86333198e-05,\n", + " 1.60658619e-04, 3.06629569e-04, 1.60658619e-04, 1.27524860e-04,\n", + " 2.90006898e-01, 0.00000000e+00, 5.01863557e-05, 0.00000000e+00,\n", + " 0.00000000e+00, 1.16398280e-04, 8.17204900e+00, 2.89421071e-05,\n", + " 9.88196756e-05, 0.00000000e+00, 9.69400227e+00, 5.91917911e-05,\n", + " 5.72064887e-05, 1.40000000e-01, 4.30759797e-01, 5.36613458e-01,\n", + " 1.14146884e-04, 9.30840769e-05, 2.16845029e-04, 1.91706908e-04,\n", + " 1.05407619e-04, 1.29543338e-04, 6.04165586e-05, 9.28872322e-05,\n", + " 2.25139549e-06, 5.72064887e-05, 1.23017786e+00, 8.85927518e-05,\n", + " 1.50634665e-04, 1.92111474e+00, 5.35242092e-05, 1.43001417e-04,\n", + " 6.00000000e-02, 5.53194200e-05, 9.88196756e-05, 5.72064887e-05,\n", + " 1.16398280e-04, 5.72064887e-05, 6.00000000e-02, 6.00000000e-02,\n", + " 2.00000765e-02, 5.53194200e-05, 1.20000000e-01, 2.59478294e-04,\n", + " 1.18011821e+00, 1.18011824e+00, 1.02001345e-04, 5.18149897e-05,\n", + " 4.63565478e+00, 4.86333198e-05, 3.84593092e-05, 8.49423376e-05,\n", + " 4.30759824e-01, 4.30759794e-01, 3.73989278e-05, 3.54341854e-05,\n", + " 1.95427528e-04, 3.53477320e+01, 1.05339199e-04, 7.07125188e-05,\n", + " 1.05339199e-04, 9.88196756e-05, 7.07125188e-05, 1.50634665e-04,\n", + " 8.18264933e+01, 5.53194200e-05, 0.00000000e+00, 0.00000000e+00,\n", + " 5.53194200e-05, 7.07096177e-01, 0.00000000e+00, 0.00000000e+00,\n", + " 1.92437702e-04, 4.14650695e-04, 6.56350171e+00, 1.41297571e-05,\n", + " 1.86223349e-05, 1.92437702e-04, 6.00005857e-02, 6.00005857e-02,\n", + " 5.63416166e-06, 7.20896577e-05, 3.56865296e-05, 1.45970950e-04,\n", + " 8.67753640e-05, 9.01268477e-05, 2.00000765e-02, 4.28605424e+00,\n", + " 5.80719381e-01, 3.05122926e-01, 1.51770491e-04, 0.00000000e+00,\n", + " 3.36203000e-05, 3.03867001e-05, 2.89421071e-05, 0.00000000e+00,\n", + " 1.05339199e-04, 9.51439244e+00, 0.00000000e+00, 3.21617508e-04,\n", + " 4.43834232e-05, 0.00000000e+00, 1.16398280e-04, 1.16398280e-04,\n", + " 9.88196756e-05, 1.16398280e-04, 1.60658619e-04, 9.88196756e-05,\n", + " 1.68911782e-04, 1.16398280e-04, 9.20423390e-06, 2.77056899e-04,\n", + " 1.60658619e-04, 1.16398280e-04, 1.05697623e-05, 8.50978098e-07,\n", + " 2.59478294e-04, 8.58247375e-05, 2.59478294e-04, 5.31417029e-05,\n", + " 1.60658619e-04, 9.88196756e-05, 1.16398280e-04, 9.88196756e-05,\n", + " 1.44330860e+00, 1.26047729e+00, 1.60658619e-04, 2.04158874e-04,\n", + " 6.46567105e-01, 1.45990300e-04, 8.30002349e-05, 0.00000000e+00,\n", + " 0.00000000e+00, 5.18149897e-05, 0.00000000e+00, 4.86333198e-05,\n", + " 0.00000000e+00, 1.52839000e-04, 9.96116128e+00, 1.63129247e+01,\n", + " 9.28872322e-05, 1.35958661e-04, 5.00000000e-02, 5.72064887e-05,\n", + " 6.00000000e-02, 1.20474057e-04, 5.35242092e-05, 1.69992972e-04,\n", + " 8.74734346e-05, 5.18149897e-05, 3.00995902e+00, 1.03600362e-04,\n", + " 1.03600362e-04, 2.63813360e-04, 1.03600362e-04, 1.03600362e-04,\n", + " 1.95609332e-01, 2.63813360e-04, 3.36203000e-05, 1.92437702e-04,\n", + " 4.49916689e-07, 4.48889951e-05, 9.88263311e-07, 9.57320384e-07,\n", + " 1.58980686e-04, 4.27674917e+00, 1.01942787e-06, 2.11719081e+01,\n", + " 2.11719081e+01, 1.60658619e-04, 1.30664928e+01, 1.30664928e+01,\n", + " 1.29543338e-04, 3.12053365e-05, 1.00201001e-04, 1.06859981e-04,\n", + " 1.06859981e-04, 1.00201001e-04, 1.06414289e-01, 2.25933018e-05,\n", + " 5.53674092e-01, 1.91706908e-04, 7.54615410e+00, 6.29335095e-05,\n", + " 6.18784642e-01, 0.00000000e+00, 0.00000000e+00, 9.88196756e-05,\n", + " 5.90566073e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", + " 4.18523262e-05, 2.72753376e-05, 0.00000000e+00, 1.05339199e-04,\n", + " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.57171149e-01,\n", + " 2.90006896e-01, 3.33928136e-01, 0.00000000e+00, 9.28872322e-05,\n", + " 2.62880204e-04, 0.00000000e+00, 5.91917911e-05, 1.40000000e-01,\n", + " 8.00000000e-02, 4.30900114e-05, 1.53375983e-04, 2.71172462e-04,\n", + " 0.00000000e+00, 1.02001345e-04, 5.90000131e-01, 1.12525909e-04,\n", + " 2.47440328e-04, 1.04384435e+00, 2.45393007e-05, 7.07096195e-01,\n", + " 9.56728257e-07, 1.46915066e+00, 7.88147012e+00, 1.45879896e-04,\n", + " 1.34826522e-04, 3.27651480e-05, 1.36975811e-04, 4.30900114e-05,\n", + " 2.13225323e+00, 5.74559726e-05, 1.01839396e-04, 4.71512744e-05,\n", + " 1.40000000e-01, 1.04384434e+00, 3.36203000e-05, 2.59478294e-04,\n", + " 2.59478294e-04, 0.00000000e+00, 8.94079415e-01, 4.38976783e+00,\n", + " 1.02001345e-04, 5.53985559e-02, 1.62292945e-01, 1.06414285e-01,\n", + " 7.20763484e-02, 1.20474057e-04, 1.05339199e-04, 9.28872322e-05,\n", + " 1.59743129e-04, 9.56460457e-05, 1.59743129e-04, 1.20474057e-04,\n", + " 5.91917911e-05, 9.99025727e-05, 9.99025727e-05, 1.45970950e-04,\n", + " 3.16381558e+01, 3.16381558e+01, 3.95678547e-06, 4.71512744e-05,\n", + " 0.00000000e+00, 4.71512744e-05, 9.34218302e+00, 9.34218302e+00,\n", + " 9.34218302e+00, 0.00000000e+00, 9.88196756e-05, 5.25704326e-06,\n", + " 1.19227209e-04, 9.28872322e-05, 9.28872322e-05, 2.57171155e-01,\n", + " 2.97719192e+01, 0.00000000e+00, 1.91706908e-04, 3.95317316e-05,\n", + " 3.73989278e-05, 8.94079415e-01, 1.60658619e-04, 4.18523262e-05,\n", + " 3.73989278e-05, 1.05339199e-04, 8.25195377e-05, 0.00000000e+00,\n", + " 1.68378995e-04, 8.80039033e+01, 1.45970950e-04, 1.60658619e-04,\n", + " 1.05339199e-04, 1.09968949e-04, 1.68378995e-04, 7.08349218e+00,\n", + " 1.80360667e-04, 8.74734346e-05, 6.18784642e-01, 9.57913060e-07,\n", + " 5.53194200e-05, 2.26469468e-04, 5.91917911e-05, 6.83372836e-01,\n", + " 0.00000000e+00, 1.04384434e+00, 1.05339199e-04, 5.72064887e-05,\n", + " 2.17865108e-04, 4.09572015e-04, 9.88196756e-05, 1.45970950e-04,\n", + " 3.36203000e-05, 3.25099783e+01, 6.18784642e-01, 2.54854080e+01,\n", + " 1.68378995e-04, 1.16398280e-04, 5.03526784e+00, 1.75258684e-04,\n", + " 1.22477289e-04, 1.80330269e-04, 9.28872322e-05, 1.12525909e-04,\n", + " 1.05339199e-04, 1.90577414e-04, 5.72064887e-05, 1.90577414e-04,\n", + " 1.78047241e-05, 3.49955848e-01, 0.00000000e+00, 9.32916254e-07,\n", + " 3.77879636e-05, 5.91917911e-05, 1.90577414e-04, 2.12385643e-05,\n", + " 4.27260969e+00, 1.60658619e-04, 3.06629569e-04, 4.30759797e-01,\n", + " 3.45093731e-05, 6.81296730e-05, 8.00004393e-02, 1.06511470e-04,\n", + " 1.06511470e-04, 1.86027034e-05, 9.20503839e-07, 9.28872322e-05,\n", + " 1.05339199e-04, 5.18149897e-05, 0.00000000e+00, 9.28872322e-05,\n", + " 1.29543338e-04, 1.05339199e-04, 7.01921550e-05, 9.39008543e-01,\n", + " 2.17367172e+01, 1.33209392e-04, 1.02001345e-04, 0.00000000e+00,\n", + " 0.00000000e+00, 5.08280495e-01, 1.80360667e-04, 3.20000000e-01,\n", + " 2.26950772e-05, 4.73683244e-05, 7.30651430e+00, 5.35242092e-05,\n", + " 5.53194200e-05, 1.05339199e-04, 1.20474057e-04, 6.66098280e-05,\n", + " 1.62501573e+00, 1.62501569e+00, 1.12525909e-04, 9.43630630e-05,\n", + " 4.86333198e-05, 1.04384434e+00, 1.51770491e-04, 2.90006896e-01,\n", + " 5.01863557e-05, 1.50634665e-04, 2.24978205e-04, 5.53194200e-05,\n", + " 1.60658619e-04, 6.18784642e-01, 1.52116418e-04, 4.93354747e-05,\n", + " 0.00000000e+00, 0.00000000e+00, 5.35242092e-05, 0.00000000e+00,\n", + " 2.13376091e+01, 0.00000000e+00, 2.26035591e-04, 0.00000000e+00,\n", + " 9.88196756e-05, 5.84100459e-05, 8.03178143e-05, 1.45970950e-04,\n", + " 5.35242092e-05, 5.18149897e-05, 5.53194200e-05, 4.71512744e-05,\n", + " 4.30900114e-05, 9.28872322e-05, 1.05339199e-04, 5.35242092e-05,\n", + " 5.01863557e-05, 1.91706908e-04, 1.75604191e+01, 1.60658619e-04,\n", + " 1.69992972e-04, 3.52909879e-01, 3.52909863e-01, 0.00000000e+00,\n", + " 2.33315732e-04, 5.35242092e-05, 5.35242092e-05, 5.72064887e-05,\n", + " 3.54341854e-05, 1.74220240e+00, 1.92437702e-04, 1.46915066e+00,\n", + " 6.94504345e-05, 9.88196756e-05, 7.07125188e-05, 1.38380365e-04,\n", + " 6.19045233e-01, 6.19045233e-01, 6.19045240e-01, 1.69992972e-04,\n", + " 5.90566073e+00, 5.90566073e+00, 0.00000000e+00, 3.27651480e-05,\n", + " 3.63966786e-05, 0.00000000e+00, 2.47440328e-04, 1.69992972e-04,\n", + " 1.97785940e-04, 9.57845942e-05, 1.50634665e-04, 2.70000000e-01,\n", + " 2.70000000e-01, 2.59478294e-04, 2.59478294e-04, 3.54341854e-05,\n", + " 8.74734346e-05, 1.45970950e-04, 1.98085258e+00, 2.28141909e-04,\n", + " 5.53674092e-01, 9.70413046e-05, 1.09968949e-04, 1.05339199e-04,\n", + " 9.84261685e-05, 2.33315732e-04, 1.29543338e-04, 7.92642428e+00,\n", + " 1.91706908e-04, 8.94079415e-01, 5.53194200e-05, 5.18149897e-05,\n", + " 9.88196756e-05, 5.53194200e-05, 1.05339199e-04, 8.74734346e-05,\n", + " 9.88196756e-05, 2.04158874e-04, 2.47513857e+00, 2.47513857e+00,\n", + " 1.20021259e-05, 1.29295148e-04, 6.12822658e-05, 1.69992972e-04,\n", + " 3.21763463e-04, 4.58739274e-04])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Returns the value of the nearest neighbor from the training set.\n", + "SimilarityDistance(threshold='auto').fit(X_train, Y_train).predict_proba(X_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Z-1NN_cv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "An optimal threshold can be found using an internal cross-validation procedure by maximizing some AD performance metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + } + ], + "source": [ + "AD_Z1NN_cv = SimilarityDistance(score=score, threshold='cv', \n", + " reg_model=est.best_estimator_).fit(X_train, Y_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, False, True, False, False, False, True, True,\n", + " True])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_Z1NN_cv[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Two-Class X-inlier/Y-outlier Classifier" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, a binary classifier learns to distinguish Y-inliers from Y-outliers. First, QRPR models are built to predict quantitative characteristics of chemical reactions. Chemical reactions with higher prediction error estimated in cross-validation (more than 3×RMSE) are labelled as Y-outliers, while the remaining reactions are labelled as Y-inliers. After that, a binary classification model is trained to discriminate between them and provide a confidence score that a given reaction is a Y-inlier for the corresponding QRPR model. Although this method seems quite straightforward, we have not found its application in literature. Unfortunately, this method cannot be applied if there are no or too few Y-outliers. In this study, Random Forest Classifier implemented in scikit-learn library was used for building the binary classification model. The method requires setting the values of two hyperparameters: max_features (the values of features selected upon tree branching) and probability threshold P. If the predicted probability of belonging to the X-inliers is greater than P, the prediction of reaction characteristics by the QRPR model for it is considered reliable (within AD). Other hyperparameters of the Random Forest Classifier was set to defaults, except the number of decision trees in Random Forest Classifier which was set to 500." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 2.5min finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 1 candidates, totalling 5 fits\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 4.6min finished\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from sklearn.model_selection import cross_val_predict\n", + "from CIMtools.applicability_domain import TwoClassClassifiers\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "\n", + "Y_predicted = cross_val_predict(RandomForestRegressor(random_state=1, n_estimators=500, \n", + " max_features=est.best_params_['max_features']), \n", + " X_train, Y_train, cv=kf, verbose=1, n_jobs=1)\n", + "Y_pr_ts = np.column_stack((Y_predicted, Y_train))\n", + "\n", + "Y_R_int = abs(Y_pr_ts[:, 0] - Y_pr_ts[:, 1]) <= 3 * np.sqrt(mean_squared_error(Y_pr_ts[:, 1], Y_pr_ts[:, 0]))\n", + "best_model_clf = GridSearchCV(RandomForestClassifier(n_estimators=500, random_state=1),\n", + " {'max_features': [None]},\n", + " cv=kf, verbose=1, n_jobs=1).fit(X_train, Y_R_int) \n", + "# [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 'auto', 'log2', 'sqrt', None]\n", + "AD_2CC = TwoClassClassifiers(threshold='cv', score=score, reg_model=est.best_estimator_, \n", + " clf_model=best_model_clf.best_estimator_).fit(X_train, Y_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, False, True, True,\n", + " True])" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_2CC[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1-SVM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The one-class Support Vector Method reveals highly populated zones in descriptor space by maximizing the distance between a separating hyperplane and the origin in the feature space implicitly defined by some Mercers’ kernel. The decision function of such model returns (+1) for the reactions which fall into highly populated zones (within AD, i.e. X-inliers) and (−1) - for the reactions outside of AD (X-outliers). 1-SVM models were built in this study using the scikit-learn library. The method requires the fitting of two hyperparameters: nu (which defines the upper bound percentage of errors and lower bound percentage of support vectors) and gamma (parameter of RBF kernel which is used), the optimal values of which can be found in cross-validation" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 2 candidates, totalling 10 fits\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 40 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 8 out of 10 | elapsed: 4.1s remaining: 1.0s\n", + "[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 4.2s finished\n" + ] + } + ], + "source": [ + "from sklearn.metrics import make_scorer\n", + "from sklearn.svm import OneClassSVM\n", + "\n", + "if score == 'ba_ad':\n", + " scorer_for_svm = make_scorer(balanced_accuracy_score_with_ad, greater_is_better=True)\n", + "else:\n", + " scorer_for_svm = make_scorer(rmse_score_with_ad, greater_is_better=True)\n", + "AD_SVM = GridSearchCV(OneClassSVM(), {'nu': [0.001, 0.005],\n", + " 'gamma': [1e-6]},\n", + " cv=kf, verbose=1, scoring=scorer_for_svm, n_jobs=-1).fit(X_train, Y_pr_ts).predict(X_test)\n", + "# [0.001, 0.005, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]\n", + "# [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100, 1000, 10000]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_SVM[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ML-dependent applicability domain definition approaches" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### RFR_VAR" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The variance of predictions made by an ensemble of QSAR/QSPR models is often used as a score for determining their AD. In this study, we extend this to QRPR models. If individual models trained on randomly selected subsets of the initial training set predict close property values than the averaged prediction made by their ensemble is considered reliable. Here, we consider a chemical reaction to be within AD (inlier) if the variance of property values predicted by the ensemble of models is less than a given threshold σ. The optimal value for σ can be found using internal cross-validation procedure by maximizing an AD quality metrics (see below). One of the approaches to estimate the prediction variance needed for this purpose is to build a QRPR model on the whole training set using the Random Forest Regression (RFR) machine learning method, which provides the mean (which is considered as a predicted value of the reaction property) and the variance of predictions (which is considered as a measure of prediction confidence) made by individual Random Trees individual models. This approach is denoted hereafter as RFR_VAR. In this study, a modified version of the Random Forest Regression method (RFR, 500 trees) implemented in scikit-learn library was used" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.base import clone\n", + "from numpy import hstack, unique\n", + "\n", + "from CIMtools.metrics.applicability_domain_metrics import balanced_accuracy_score_with_ad, rmse_score_with_ad\n", + "\n", + "def threshold(X, y, model, score):\n", + " cv = KFold(n_splits=5, random_state=1, shuffle=True)\n", + " model_int = clone(model)\n", + "\n", + " threshold_value, score_value = 0, 0\n", + " Y_pred, Y_true, AD = [], [], []\n", + " for train_index, test_index in cv.split(X):\n", + " x_train = safe_indexing(X, train_index)\n", + " x_test = safe_indexing(X, test_index)\n", + " y_train = safe_indexing(y, train_index)\n", + " y_test = safe_indexing(y, test_index)\n", + " model_int.fit(x_train, y_train)\n", + "\n", + " AD.extend(model_int.predict_proba(x_test)) \n", + " Y_pred.append(model_int.predict(x_test))\n", + " Y_true.append(y_test)\n", + " AD_stack = hstack(AD)\n", + " AD_ = unique(AD_stack)\n", + " for z in AD_:\n", + " AD_new = AD_stack <= z\n", + " if score == 'ba_ad':\n", + " val = balanced_accuracy_score_with_ad(Y_true=hstack(Y_true), Y_pred=hstack(Y_pred), AD=AD_new)\n", + " elif score == 'rmse_ad':\n", + " val = rmse_score_with_ad(Y_true=hstack(Y_true), Y_pred=hstack(Y_pred), AD=AD_new)\n", + " if val >= score_value:\n", + " score_value = val\n", + " threshold_value = z\n", + " return threshold_value, score_value" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.ensemble.forest module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.ensemble. Anything that cannot be imported from sklearn.ensemble is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/externals/joblib/__init__.py:15: FutureWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.ensemble.base module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.ensemble. Anything that cannot be imported from sklearn.ensemble is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.tree.tree module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.tree. Anything that cannot be imported from sklearn.tree is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n" + ] + } + ], + "source": [ + "from random_forest_variance import RandomForestRegressor2" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + } + ], + "source": [ + "est_var = RandomForestRegressor2(random_state=1, n_estimators=500,\n", + " max_features=est.best_params_['max_features']).fit(X_train, Y_train)\n", + "AD_est_var_values = est_var.predict_proba(X_test)\n", + "min_h_param_RFR_VAR = threshold(X=X_train, y=Y_train, model=est_var, score=score) # для нахождения отсечки\n", + "AD_RFR_VAR = AD_est_var_values <= min_h_param_RFR_VAR[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, False, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_RFR_VAR[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### GPR-AD" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Gaussian Process Regression (GPR)** assumes that the joint distribution of a real-valued property of chemical reactions and their descriptors is multivariate normal (Gaussian) with the elements of its covariance matrix computed by means of special covariance functions (kernels). For every reaction, a GPR model produces using the Bayes’ theorem a posterior conditional distribution (so-called prediction density) of the reaction property given the vector of reaction descriptors. The prediction density has normal (Gaussian) distribution with the mean corresponding to predicted value of the property and the variance corresponding to prediction confidence. If the variance is greater than a predefined threshold σ, the chemical reaction is considered as X-outlier (out of AD). This AD definition method is denoted as GPR-AD. The method requires adjustment of three hyperparameter - alpha which stands for the noise level (also acts as regularization of the model), the parameter gamma of the RBF kernel which represents the covariance function and variance threshold σ. The optimal values of hyperparameters are determined using internal cross-validation. Other hyperparameters of Gaussian Processes are set by default." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n", + "/home/assima/env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", + " warnings.warn(msg, category=FutureWarning)\n" + ] + } + ], + "source": [ + "from CIMtools.applicability_domain import GPR_AD\n", + "AD_GPR = GPR_AD(threshold='cv', score='ba_ad', \n", + " gpr_model=gpr_grid.best_estimator_).fit(X_train, Y_train).predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, True, True, True, True, True, True, True, True,\n", + " True])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "AD_GPR[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Zero Models and \"Perfect Model\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Zero models can be suggested to compare different AD definitions with some simple rules. Zero models include “optimistic” AD definitions assuming that all reactions are within AD (denoted as OZ), while the “pessimistic” AD definition assumes that all reactions are out of AD (denoted as PZ). In addition to zero models, the “perfect” AD definition can be proposed for comparison. The “Perfect AD model” assumes that all X-inliers are Y-inliers (i.e. for all reactions within AD absolute property prediction error is lower than 3×RMSE), and all X-outliers are Y-outliers (i.e. for all reactions outside AD absolute property prediction error is higher than 3×RMSE)." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "AD_ZeroModel_One = np.ones(X_test.shape[0]) # all reactions are in AD\n", + "AD_ZeroModel_Zero = np.zeros(X_test.shape[0]) # all reactions are not AD\n", + "AD_Perfect = abs(Y_pred - Y_test) <= 3 * np.sqrt(-est.best_score_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And so we go through the remaining 4 folds. \n", + "Then we collect all the results and calculate the characteristics..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# def assembly_of_results(name, AD_list, num, AD, RTC1=None):\n", + "# AD_list[num].extend(AD)\n", + "# if RTC1 is not None:\n", + "# AD_list[num+1].extend(np.logical_and(AD, RTC1))\n", + "# print('{}_is_done!'.format(name))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from CIMtools.applicability_domain import ReactionTypeControl\n", + "# RTC_1 = ReactionTypeControl(env=1).fit(reactions_train).predict(reactions_test)\n", + "# assembly_of_results('Reaction Type Control with R=1', AD_rfr, 0, RTC_1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "position": { + "height": "399px", + "left": "1558px", + "right": "20px", + "top": "125px", + "width": "350px" + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..44c1001 --- /dev/null +++ b/environment.yml @@ -0,0 +1,14 @@ +channels: + - conda-forge +dependencies: + - python>=3.7.0,<3.8 + - pandoc + - ipython + - pip + - pip: + - wheel + - Flask-Sphinx-Themes + - numpydoc + - nbsphinx + - git+https://github.com/cimm-kzn/CGRtools.git@master#egg=CGRtools[mrv] + - . diff --git a/readthedocs.yml b/readthedocs.yml index 71aa8ab..2c25c00 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,14 +1,7 @@ version: 2 -python: - version: 3.7 - install: - - requirements: requirements.txt - - method: pip - path: . - -build: - image: latest +conda: + environment: environment.yml sphinx: builder: html diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index fa248b5..0000000 --- a/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -wheel -m2r -Flask-Sphinx-Themes -numpydoc -git+https://github.com/cimm-kzn/CGRtools.git@master#egg=CGRtools[mrv] \ No newline at end of file diff --git a/setup.py b/setup.py index befd09c..720b2a1 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def finalize_options(self): 'pyparsing>=2.2.0,<2.5', 'pyjnius>=1.3.0'], data_files=[('bin', fragmentor), ('lib', ['RDtool/rdtool.jar'])], zip_safe=False, - long_description=(Path(__file__).parent / 'README.md').open().read(), + long_description=(Path(__file__).parent / 'README.rst').open().read(), classifiers=['Environment :: Plugins', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)', @@ -70,5 +70,8 @@ def finalize_options(self): 'Topic :: Scientific/Engineering :: Information Analysis', 'Topic :: Software Development', 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules'] + 'Topic :: Software Development :: Libraries :: Python Modules'], + command_options={'build_sphinx': {'source_dir': ('setup.py', 'doc'), + 'build_dir': ('setup.py', 'build/doc'), + 'all_files': ('setup.py', True)}} )