Config driven detectors - part 1 #458
Conversation
@@ -94,6 +93,27 @@ def infer_threshold(self, X: np.ndarray) -> None:
        pass


class DriftConfigMixin:
Short docstring about what this class is for would be useful (same stuff as in the PR comments really).
alibi_detect/utils/warnings.py
Outdated
if new in kwargs:
    raise ValueError(f"{func_name} received both the deprecated kwarg `{alias}` "
                     f"and its replacement `{new}`.")
warnings.warn(f'`{alias}` is deprecated; use `{new}`.', DeprecationWarning, 3)
Question - why stacklevel 3?

Also, I don't think we fully arrived at a conclusion on whether using `UserWarning` over `DeprecationWarning` would be better. `UserWarning` is always shown and won't be turned off by Python optimizations, whereas that's not the case for `DeprecationWarning`, which is more intended for developers. That being said, the users of `alibi` are developers, so the point is perhaps not important.
I was convinced we'd decided on this, but can't find anything so may have imagined it. I don't have much of an opinion either way. @mauicv any thoughts?
p.s. re `stacklevel`, I'm asking myself the same question. Trying to remind myself...
@jklaise just checked and with `stacklevel != 3` the warning isn't raised. I assume it's swallowed otherwise, since the warning is raised in a decorated `__init__` method of a class. Am honestly a little confused about this one.
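A minimal sketch of why `stacklevel=3` ends up pointing at the user's call site here: when the warning is emitted from a helper function called inside a decorator's wrapper, there are two intermediate frames between the `warn` call and user code. The decorator and class below are illustrative, loosely based on the snippet quoted above, not alibi-detect's actual implementation:

```python
import functools
import warnings


def _rename_kwargs(func_name, kwargs, aliases):
    for alias, new in aliases.items():
        if alias in kwargs:
            if new in kwargs:
                raise ValueError(f"{func_name} received both the deprecated kwarg "
                                 f"`{alias}` and its replacement `{new}`.")
            # Frame 1 = this helper, frame 2 = the decorator's wrapper,
            # frame 3 = the user's call site, hence stacklevel=3.
            warnings.warn(f'`{alias}` is deprecated; use `{new}`.',
                          DeprecationWarning, stacklevel=3)
            kwargs[new] = kwargs.pop(alias)


def deprecated_alias(**aliases):
    """Map deprecated kwarg names onto their replacements (illustrative)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            _rename_kwargs(func.__name__, kwargs, aliases)
            return func(*args, **kwargs)
        return wrapper
    return decorator


class Detector:
    @deprecated_alias(preprocess_x_ref='preprocess_at_init')
    def __init__(self, preprocess_at_init=True):
        self.preprocess_at_init = preprocess_at_init
```

With a smaller `stacklevel`, the warning is attributed to the decorator machinery's own frames, which some warning filters (e.g. "once per location" defaults) can then suppress.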
I don't have a strong opinion on this tbh. I think I'm comfortable with `DeprecationWarning` as our users are developers!
-------
    The detector's configuration dictionary.
"""
cfg = super().get_config()
Question - why `super` here, but in other places use attribute access via the mixin class?
Same question for a few other detectors below.
The `get_config` is done at a number of levels. The mixin class does the top-level drift detector generic stuff, i.e. adds `version` and `x_ref`, then most of the attributes are added in the more specific subclasses, i.e. `BaseUnivariateDrift.get_config()`. Then any final detector-specific bits are added in the detector-specific subclass method. There is nothing specific for `CVMDrift`, but as an example, `FETDrift` adds in the `alternative` field:
alibi-detect/alibi_detect/cd/fet.py
Lines 119 to 135 in 9eb5ace
def get_config(self) -> dict:
    """
    Get the detector's configuration dictionary.

    Returns
    -------
    The detector's configuration dictionary.
    """
    cfg = super().get_config()
    # Detector kwargs
    kwargs = {
        'alternative': self.alternative,
    }
    cfg.update(kwargs)
    return cfg
This can't be added in `BaseUnivariateDrift` since not all univariate drift detectors have an `alternative` kwarg.
P.s. The use of this mixin is probably a good argument for having a drift detector base class (in between `BaseDetector` and `BaseUnivariateDrift` etc.). However, I didn't want to change the core of the code too much in this PR...
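The layered `get_config` pattern described above can be sketched as follows. This is a simplified illustration, not alibi-detect's actual code: the mixin contributes the generic fields, the base class contributes kwargs shared by all univariate detectors, and the detector subclass contributes its own kwargs:

```python
class DriftConfigMixin:
    """Top-level config fields shared by all drift detectors (sketch)."""
    def get_config(self) -> dict:
        # 'version' would come from the package in the real implementation
        return {'version': '0.x.x', 'x_ref': self.x_ref}


class BaseUnivariateDrift(DriftConfigMixin):
    def __init__(self, x_ref, p_val=0.05):
        self.x_ref, self.p_val = x_ref, p_val

    def get_config(self) -> dict:
        cfg = super().get_config()           # mixin-level generic fields
        cfg.update({'p_val': self.p_val})    # fields common to univariate detectors
        return cfg


class FETDrift(BaseUnivariateDrift):
    def __init__(self, x_ref, p_val=0.05, alternative='two-sided'):
        super().__init__(x_ref, p_val)
        self.alternative = alternative

    def get_config(self) -> dict:
        cfg = super().get_config()
        cfg.update({'alternative': self.alternative})  # detector-specific kwarg
        return cfg
```

Each level only adds the kwargs it owns, so `FETDrift(...).get_config()` merges all three layers into one dictionary.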
alibi_detect/cd/mmd.py
Outdated
@@ -92,6 +100,7 @@ def __init__(
        else:
            self._detector = MMDDriftTorch(*args, **kwargs)  # type: ignore
        self.meta = self._detector.meta
        self._detector.backend = backend
I assume this info is needed by the backend class so it can call its own specific `get_config` method and save that info?
This isn't actually used at all right now, but I will amend as discussed in next comment.
I've got rid of this, as the `backend` is actually already stored in `self.meta['backend']`. I'm then accessing this in the base classes to reduce duplication a little.
alibi_detect/cd/pytorch/lsdd.py
Outdated
cfg = super().get_config()

# backend
cfg.update({'backend': 'pytorch'})
I thought the `backend` kwarg was set as an attribute of the backend class by the agnostic class, hence wouldn't need to be hard-coded here? (Same comment applies to `mmd` etc. below.)
This wasn't actually the original intention (the `self.backend` was intended for use at save time, in the later PR). However, this is actually a nice idea, since we can then move the `cfg.update({'backend': self.backend})` up to the base class.
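A minimal sketch of the change agreed here: the backend-agnostic wrapper stores the `backend` string once on the backend instance, and a base-class `get_config` reads it back, so no backend subclass needs to hard-code `'pytorch'` or `'tensorflow'`. Class names and structure below are illustrative, not alibi-detect's actual code:

```python
class BaseDetectorBackend:
    backend = None  # set externally by the backend-agnostic wrapper

    def get_config(self) -> dict:
        cfg = {}
        if self.backend is not None:
            # Single place the backend string enters the config
            cfg.update({'backend': self.backend})
        return cfg


class LSDDDriftTorch(BaseDetectorBackend):
    """PyTorch backend implementation (sketch); no hard-coded backend string."""


class LSDDDrift:
    """Backend-agnostic wrapper (sketch)."""
    def __init__(self, backend: str = 'pytorch'):
        self._detector = LSDDDriftTorch()
        self._detector.backend = backend  # stored once, read in base get_config
```

The same effect could be achieved by reading `self.meta['backend']` in the base class, as mentioned in the comment above.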
-------
    The detector's configuration dictionary.
"""
raise RuntimeError("get_config not yet implemented for SpotTheDiffDrift with pytorch backend.")
Other places have `NotImplementedError`.
LGTM, happy to merge if the questions are answered.
Looks good to me!
# optionally already preprocess reference data
self.p_val = p_val
if preprocess_x_ref and isinstance(preprocess_fn, Callable):  # type: ignore[arg-type]
    # x_ref preprocessing logic
What was the reason you didn't factor this out into a mixin or separate function? I guess it's not the same in every base class?
I agree with you that it would certainly be nice to factor out. But I didn't want to do that as part of this PR, as it's a change to the core of the detectors and not really related to the config functionality (it only looks as if I've added stuff above because I've had to rework the logic a bit due to the kwarg changes).

Also, I think the above is motivation for adding a drift detector base class, i.e. `BaseDrift(BaseDetector)`, which `BaseUnivariateDrift` etc. would inherit from, rather than adding lots of mixins...
Changes to detector kwargs and addition of get_config methods.
This PR implements a number of final fixes and refactorings for the config driven detectors:

1. The `preprocess_at_init` and `preprocess_at_pred` logic implemented in #381 and #458 has been reworked. It turned out to have a problem in how it dealt with `update_x_ref`, since regardless of `x_ref_preprocessed`, we still need to update reference data within `.predict()` when `update_x_ref` is set. All offline drift detectors have been reworked to use the old logic (but with `preprocess_x_ref` renamed to `preprocess_at_init`), with the addition that `self.x_ref_preprocessed` is also checked internally.
2. The previous `get_config` methods involved a lot of boilerplate to try to recover the original args/kwargs from detector attributes. The new approach calls a generic `_set_config()` within `__init__`, and then `self.config` is returned by `get_config`. This should significantly reduce the workload to add save/load to new detectors. To avoid memory overheads, large artefacts such as `x_ref` are not set at `__init__`, and are instead added within `get_config`.
3. Owing to the ease of implementation with the new `get_config` approach, save/load has been added for the model uncertainty and online detectors!
4. Kernels and `preprocess_fn`'s were previously resolved in `_load_detector_config`, which wasn't consistent with how other artefacts were resolved (it also caused extra challenges). These are now resolved in `resolve_config` instead. Following this, the `KernelConfigResolved` and `PreprocessConfigResolved` pydantic models have been removed (they could be added back, but it would complicate `resolve_config`).
5. Fixing determinism in #496 has allowed us to compare original and loaded detector predictions in `test_saving.py`. This uncovered bugs with how kernels were saved and loaded. These have been fixed.
6. The readthedocs.yml has been fully updated to the V2 schema so that we can use Python 3.9 for building the docs.
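The generic `_set_config()`/`get_config` approach described above can be sketched as follows. This is an illustrative mock-up, not alibi-detect's actual implementation: the config is captured once at `__init__` time, while large artefacts such as `x_ref` are deferred and only merged in when `get_config` is called:

```python
class ConfigMixin:
    """Sketch of a generic _set_config/get_config pattern (names illustrative)."""

    # Large artefacts deferred to get_config to avoid holding two copies in memory
    LARGE_ARTEFACTS = ('x_ref',)

    def _set_config(self, inputs: dict):
        # Capture the __init__ args/kwargs, minus `self` and large artefacts
        self.config = {k: v for k, v in inputs.items()
                       if k != 'self' and k not in self.LARGE_ARTEFACTS}

    def get_config(self) -> dict:
        cfg = dict(self.config)
        # Large artefacts are added lazily, only when the config is requested
        for name in self.LARGE_ARTEFACTS:
            cfg[name] = getattr(self, name)
        return cfg


class MyDrift(ConfigMixin):
    def __init__(self, x_ref, p_val=0.05):
        self.x_ref = x_ref
        self._set_config(locals())  # one generic call replaces the old boilerplate
```

The key saving is that each detector's `__init__` makes a single `_set_config(locals())`-style call instead of reconstructing its original kwargs from attributes inside a bespoke `get_config`.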
This is required as the `class NDArray(Generic[T], np.ndarray[Any, T])` in `utils._typing` causes an error with `autodoc` on older Python versions.
This is the first of a series of PR's for the config-driven detector functionality. The original PR (#389) has been split into a number of smaller PR's to aid the review process.
Note: This initial PR is not fully functioning, in the sense that the `alibi_detect.utils.saving` submodule still expects the old `preprocess_x_ref` kwarg and attribute. I don't think it is worth reworking `saving` with these updated kwargs, as later part-PRs will significantly rework the `saving` submodule.

Summary of PR
This PR includes two main changes:
1. Modification of detector kwargs

The `preprocess_x_ref` kwarg for all offline detectors is separated into two kwargs, `x_ref_preprocessed` and `preprocess_at_init`, in order to draw a distinction between whether or not to preprocess `x_ref` (`x_ref_preprocessed`), and when to preprocess `x_ref` (`preprocess_at_init`). See WIP: Config driven detectors #389 (comment) and Preprocess poc #381.

The `input_shape` kwarg was missing from some detectors, and has been added as it is used when saving preprocessing models (see WIP: Config driven detectors #389 (comment)).

2. Addition of `get_config` methods

This method has been added to each detector. It takes a detector's attributes and returns a config dictionary. This config dictionary will form the basis of the later saving work, with artifacts in the config being saved, before the config dict is saved as a toml file.

The `get_config` methods are added following the usual hierarchical approach with respect to the base classes and backend-specific subclasses. For example, the generalizable bits of the config are defined in a mixin in `alibi_detect/base.py`, the backend-agnostic bits are defined in `alibi_detect/cd/base.py`, and the backend-specific bits are then defined in the backend subclasses, e.g. `MMDDriftTF`.

Complexity arises here because some kwargs are conditional on others, and the initial kwargs are not always added as attributes, therefore it can be complicated to infer the original values of the kwargs at `get_config` time (see #389 (comment)).
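The split between `x_ref_preprocessed` (whether `x_ref` is already preprocessed) and `preprocess_at_init` (when to preprocess it) can be sketched as below. This is illustrative logic only, not the actual detector code; `_preprocess` is a hypothetical helper standing in for the detectors' predict-time preprocessing:

```python
class DriftDetector:
    """Sketch of the kwarg split described above."""

    def __init__(self, x_ref, preprocess_fn=None,
                 preprocess_at_init=True, x_ref_preprocessed=False):
        self.preprocess_fn = preprocess_fn
        self.preprocess_at_init = preprocess_at_init
        self.x_ref_preprocessed = x_ref_preprocessed
        # Preprocess x_ref now only if it isn't already preprocessed
        # AND we were asked to do it at init time.
        if preprocess_fn is not None and preprocess_at_init and not x_ref_preprocessed:
            x_ref = preprocess_fn(x_ref)
        self.x_ref = x_ref

    def _preprocess(self, x):
        # At predict time, x_ref still needs preprocessing if it was deferred
        if self.preprocess_fn is not None:
            x = self.preprocess_fn(x)
            if not self.preprocess_at_init and not self.x_ref_preprocessed:
                return self.preprocess_fn(self.x_ref), x
        return self.x_ref, x
```

With one boolean (`preprocess_x_ref`) these two questions were conflated: a user who had already preprocessed their reference data had no way to say so independently of when preprocessing should happen.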