Config driven detectors - part 4 (#497)
This PR implements a number of final fixes and refactorings for the config-driven detectors:

1. The `preprocess_at_init` and `preprocess_at_pred` logic implemented in #381 and #458 has been reworked. It turned out to have a problem in how it dealt with `update_x_ref`: regardless of `x_ref_preprocessed`, we still need to update the reference data within `.predict()` when `update_x_ref` is set. All offline drift detectors have been reworked to use the old logic (but with `preprocess_x_ref` renamed to `preprocess_at_init`), with the addition that `self.x_ref_preprocessed` is also checked internally (see the sketch below this list).
2. The previous `get_config` methods involved a lot of boilerplate to try to recover the original args/kwargs from detector attributes. The new approach calls a generic `_set_config()` from `__init__`, and `get_config` then simply returns `self.config`. This should significantly reduce the work needed to add save/load support to new detectors. To avoid memory overheads, large artefacts such as `x_ref` are not stored in the config at `__init__` time, and are instead added back within `get_config` (a usage sketch follows the `alibi_detect/base.py` diff below).
3. Owing to the ease of implementation with the new `get_config` approach, save/load has been added for the model uncertainty and online detectors!
4. Kernels and `preprocess_fn`'s were previously resolved in `_load_detector_config`, which wasn't consistent with how other artefacts were resolved (and added extra challenges). These are now resolved in `resolve_config` instead. Following this, the `KernelConfigResolved` and `PreprocessConfigResolved` pydantic models have been removed (they could be added back, but doing so would complicate `resolve_config`).
5. Fixing determinism in #496 has allowed us to compare the predictions of original and loaded detectors in `test_saving.py`. This uncovered bugs in how kernels were saved and loaded, which have now been fixed (a round-trip sketch of this check is given below this list).
6. The .readthedocs.yml has been fully updated to the V2 schema so that we can use Python 3.9 to build the docs. This is required because `class NDArray(Generic[T], np.ndarray[Any, T])` in `utils._typing` causes an error with `autodoc` on older Python versions.
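
To make item 1 concrete, below is a minimal sketch of the intended behaviour. It is illustrative only (the class, the stand-in test statistic and the `'last'`-style reference update are invented for this example), not the library's actual implementation:

```python
import numpy as np
from typing import Callable, Optional


class RefPreprocessSketch:
    """Illustrative only; not alibi-detect's actual classes or exact logic."""

    def __init__(self, x_ref: np.ndarray, preprocess_fn: Optional[Callable] = None,
                 preprocess_at_init: bool = True, x_ref_preprocessed: bool = False,
                 update_x_ref: Optional[dict] = None):
        self.preprocess_fn = preprocess_fn
        self.preprocess_at_init = preprocess_at_init
        self.x_ref_preprocessed = x_ref_preprocessed
        self.update_x_ref = update_x_ref
        # Preprocess the reference data up front only if requested and not already done by the user.
        if preprocess_fn is not None and preprocess_at_init and not x_ref_preprocessed:
            self.x_ref = preprocess_fn(x_ref)
        else:
            self.x_ref = x_ref

    def preprocess(self, x: np.ndarray):
        """Return (x_ref, x) preprocessed consistently at prediction time."""
        if self.preprocess_fn is None:
            return self.x_ref, x
        x = self.preprocess_fn(x)
        if self.preprocess_at_init or self.x_ref_preprocessed:
            return self.x_ref, x                  # x_ref is already preprocessed
        return self.preprocess_fn(self.x_ref), x  # preprocess x_ref on the fly

    def predict(self, x: np.ndarray) -> dict:
        x_ref, x_proc = self.preprocess(x)
        is_drift = int(abs(x_ref.mean() - x_proc.mean()) > 1.)  # stand-in for a real test statistic
        # Key point of the rework: the stored reference data is still updated whenever
        # update_x_ref is set, regardless of x_ref_preprocessed.
        if self.update_x_ref is not None:
            # keep self.x_ref in whichever (raw or preprocessed) space it is stored in
            x_new = x_proc if (self.preprocess_at_init or self.x_ref_preprocessed) else x
            n_keep = self.update_x_ref.get('last', len(self.x_ref))
            self.x_ref = np.concatenate([self.x_ref, x_new], axis=0)[-n_keep:]
        return {'is_drift': is_drift}


det = RefPreprocessSketch(np.zeros((100, 3)), update_x_ref={'last': 100})
print(det.predict(np.ones((10, 3))))
```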
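
And a sketch of the round-trip check referred to in item 5 (the detector choice, data and target directory are illustrative, and the `save_detector`/`load_detector` entry points in `alibi_detect.saving` are assumed; the actual checks live in `test_saving.py`):

```python
import numpy as np
from alibi_detect.cd import KSDrift
from alibi_detect.saving import save_detector, load_detector

x_ref = np.random.randn(500, 5).astype(np.float32)
x_test = np.random.randn(200, 5).astype(np.float32)

cd = KSDrift(x_ref, p_val=.05)
save_detector(cd, 'ks_detector/')            # config-driven save to a directory
cd_loaded = load_detector('ks_detector/')    # reconstruct the detector from the saved config

# With determinism fixed (#496), the original and loaded detectors should agree exactly.
np.testing.assert_array_equal(cd.predict(x_test)['data']['p_val'],
                              cd_loaded.predict(x_test)['data']['p_val'])
```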
ascillitoe authored May 25, 2022
1 parent a87153b commit a9b377f
Showing 63 changed files with 1,830 additions and 1,430 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -79,14 +79,14 @@ jobs:
     runs-on: ubuntu-18.04

     container:
-      image: readthedocs/build:latest
+      image: readthedocs/build:7.0 # 7.0 to get Python 3.9
       options: --user root

     steps:
       - uses: actions/checkout@v2
       - name: Create a virtualenv to use for docs build
         run: |
-          python3.8 -m virtualenv $HOME/docs
+          python3.9 -m virtualenv $HOME/docs
       - name: Install dependencies
         run: |
           . $HOME/docs/bin/activate
14 changes: 8 additions & 6 deletions .readthedocs.yml
@@ -5,21 +5,23 @@
 # Required
 version: 2

+# Set the version of Python and other tools you might need
 build:
-  image: latest # Python 3.8 available on latest
+  os: ubuntu-20.04
+  tools:
+    python: "3.9"
+  apt_packages:
+    - pandoc

 # Build documentation in the docs/ directory with Sphinx
 sphinx:
-  configuration: doc/source/conf.py
+  configuration: doc/source/conf.py

 # Optionally build your docs in additional formats such as PDF
 formats:
-  - pdf
+  - pdf

 # Optionally set the version of Python and requirements required to build your docs
 python:
-  version: 3.8
-  install:
-    - requirements: requirements/docs.txt
+  install:
+    - requirements: requirements/docs.txt
102 changes: 74 additions & 28 deletions alibi_detect/base.py
@@ -2,14 +2,14 @@
 import copy
 import json
 import numpy as np
-from typing import Dict, Any, Optional, Callable
+from typing import Dict, Any, Optional
 from alibi_detect.version import __version__, __config_spec__

 DEFAULT_META = {
     "name": None,
     "detector_type": None, # online or offline
     "data_type": None, # tabular, image or time-series
-    "version": None,
+    "version": None
 } # type: Dict

@@ -93,40 +93,86 @@ def infer_threshold(self, X: np.ndarray) -> None:
         pass


+# "Large artefacts" - to save memory these are skipped in _set_config(), but added back in get_config()
+# Note: The current implementation assumes the artefact is stored as a class attribute, and as a config field under
+# the same name. Refactoring will be required if this assumption is to be broken.
+LARGE_ARTEFACTS = ['x_ref', 'c_ref', 'preprocess_fn']
+
+
 class DriftConfigMixin:
     """
-    A mixin class to be used by detector `get_config` methods. The `drift_config` method defines the initial
-    generic configuration dict for all detectors, which is then fully populated by a detector's get_config method(s).
+    A mixin class containing methods related to a drift detector's configuration dictionary.
     """
-    x_ref: np.ndarray
-    preprocess_fn: Optional[Callable] = None
-
-    def drift_config(self):
+    config: Optional[dict] = None
+
+    def get_config(self) -> dict: # TODO - move to BaseDetector once config save/load implemented for non-drift
+        """
+        Get the detector's configuration dictionary.
+        Returns
+        -------
+        The detector's configuration dictionary.
+        """
+        if self.config is not None:
+            # Get config (stored in top-level self)
+            cfg = self.config
+            # Get low-level nested detector (if needed)
+            detector = self._detector if hasattr(self, '_detector') else self # type: ignore[attr-defined]
+            detector = detector._detector if hasattr(detector, '_detector') else detector # type: ignore[attr-defined]
+            # Add large artefacts back to config
+            for key in LARGE_ARTEFACTS:
+                if key in cfg: # self.config is validated, therefore if a key is not in cfg, it isn't valid to insert
+                    cfg[key] = getattr(detector, key)
+            # Set x_ref_preprocessed flag
+            preprocess_at_init = getattr(detector, 'preprocess_at_init', True) # If no preprocess_at_init, always true!
+            cfg['x_ref_preprocessed'] = preprocess_at_init and detector.preprocess_fn is not None
+            return cfg
+        else:
+            raise NotImplementedError('Getting a config (or saving via a config file) is not yet implemented for this '
+                                      'detector')
+
+    @classmethod
+    def from_config(cls, config: dict):
+        """
+        Instantiate a drift detector from a fully resolved (and validated) config dictionary.
+        Parameters
+        ----------
+        config
+            A config dictionary matching the schemas in :class:`~alibi_detect.saving.schemas`.
+        """
+        # Check for existing version_warning. meta is popped as we don't want to pass it as an arg/kwarg
+        version_warning = config.pop('meta', {}).pop('version_warning', False)
+        # Init detector
+        detector = cls(**config)
+        # Add version_warning
+        detector.meta['version_warning'] = version_warning # type: ignore[attr-defined]
+        detector.config['meta']['version_warning'] = version_warning
+        return detector
+
+    def _set_config(self, inputs): # TODO - move to BaseDetector once config save/load implemented for non-drift
         # Set config metadata
         name = self.__class__.__name__
         # strip off any backend suffix
         backends = ['TF', 'Torch', 'Sklearn']
         for backend in backends:
             if name.endswith(backend):
                 name = name[:-len(backend)]
-        # Init config dict
-        cfg: Dict[str, Any] = {'name': name}
-
-        # Add x_ref
-        cfg.update({'x_ref': self.x_ref})
+        # Init config dict
+        self.config: Dict[str, Any] = {
+            'name': name,
+            'meta': {
+                'version': __version__,
+                'config_spec': __config_spec__,
+            }
+        }

-        # Add preprocess_fn field
-        if self.preprocess_fn is not None:
-            cfg.update({'preprocess_fn': self.preprocess_fn})
+        # args and kwargs
+        pop_inputs = ['self', '__class__', '__len__', 'name', 'meta']
+        [inputs.pop(k, None) for k in pop_inputs]

-        # Populate meta dict and add to config
-        cfg_meta = {
-            'version': __version__,
-            'config_spec': __config_spec__,
-            'version_warning': self.meta.get('version_warning', False)
-        }
-        cfg.update({'meta': cfg_meta})
+        # Overwrite any large artefacts with None to save memory. They'll be added back by get_config()
+        for key in LARGE_ARTEFACTS:
+            if key in inputs:
+                inputs[key] = None

-        return cfg
+        self.config.update(inputs)


 class NumpyEncoder(json.JSONEncoder):
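
To illustrate how the new `_set_config`/`get_config` machinery above is intended to be used, here is a minimal, hypothetical detector (the class and its arguments are invented for illustration; real detectors pass their `__init__` arguments to `_set_config` in a similar way):

```python
import numpy as np
from alibi_detect.base import BaseDetector, DriftConfigMixin


class MyDriftDetector(BaseDetector, DriftConfigMixin):  # hypothetical detector
    def __init__(self, x_ref: np.ndarray, p_val: float = .05):
        super().__init__()
        # Capture the init args once; large artefacts (x_ref here) are stored as None to save memory.
        self._set_config(locals())
        self.x_ref = x_ref
        self.p_val = p_val
        self.preprocess_fn = None  # get_config() inspects this when setting x_ref_preprocessed

    def score(self, x: np.ndarray):
        ...

    def predict(self, x: np.ndarray):
        ...


detector = MyDriftDetector(np.zeros((100, 3)))
cfg = detector.get_config()  # x_ref is added back into the config here
assert cfg['name'] == 'MyDriftDetector' and cfg['x_ref'] is detector.x_ref
```

Because `self.config` is populated generically at `__init__`, `get_config` needs no per-detector boilerplate, and `from_config` can rebuild the detector from the (resolved) dictionary.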